Skip to content

Commit 682640e

Browse files
committed
improve learning of new classes
1 parent 5398fc8 commit 682640e

File tree

7 files changed

+917
-356
lines changed

7 files changed

+917
-356
lines changed

examples/advanced_usage.py

+144-26
Original file line numberDiff line numberDiff line change
@@ -24,44 +24,126 @@ def demonstrate_batch_processing():
2424
logger.info("Demonstrating batch processing...")
2525

2626
# Initialize classifier
27-
classifier = AdaptiveClassifier("bert-base-uncased")
27+
classifier = AdaptiveClassifier("distilbert/distilbert-base-cased")
2828

2929
# Create a larger dataset
3030
texts = []
3131
labels = []
3232

3333
# Simulate customer feedback dataset
3434
feedback_data = [
35+
# Positive feedback
3536
("The product is amazing!", "positive"),
37+
("Exceeded all my expectations, truly worth every penny", "positive"),
38+
("Customer service was incredibly helpful and responsive", "positive"),
39+
("Best purchase I've made this year", "positive"),
40+
("The quality is outstanding", "positive"),
41+
("Shipping was super fast and packaging was perfect", "positive"),
42+
("Really impressed with the durability", "positive"),
43+
("Great value for money", "positive"),
44+
("The features are exactly what I needed", "positive"),
45+
("Easy to use and very intuitive", "positive"),
46+
("Fantastic product, will definitely buy again", "positive"),
47+
("Love how lightweight and portable it is", "positive"),
48+
("The installation process was seamless", "positive"),
49+
("Brilliant design and functionality", "positive"),
50+
("Top-notch quality and performance", "positive"),
51+
52+
# Negative feedback
3653
("Worst experience ever", "negative"),
54+
("Product broke after just one week", "negative"),
55+
("Customer support never responded to my emails", "negative"),
56+
("Completely disappointed with the quality", "negative"),
57+
("Not worth the money at all", "negative"),
58+
("Arrived damaged and return process was horrible", "negative"),
59+
("The instructions were impossible to follow", "negative"),
60+
("Poor build quality, feels cheap", "negative"),
61+
("Missing essential features that were advertised", "negative"),
62+
("Terrible battery life", "negative"),
63+
("Keeps malfunctioning randomly", "negative"),
64+
("The worst customer service I've ever experienced", "negative"),
65+
("Save your money and avoid this product", "negative"),
66+
("Doesn't work as advertised", "negative"),
67+
("Had to return it immediately", "negative"),
68+
69+
# Neutral feedback
3770
("It works as expected", "neutral"),
38-
# Add more examples...
71+
("Average product, nothing special", "neutral"),
72+
("Does the job, but could be better", "neutral"),
73+
("Reasonable price for what you get", "neutral"),
74+
("Some good features, some bad ones", "neutral"),
75+
("Pretty standard quality", "neutral"),
76+
("Not bad, not great", "neutral"),
77+
("Meets basic requirements", "neutral"),
78+
("Similar to other products in this category", "neutral"),
79+
("Acceptable performance for the price", "neutral"),
80+
("Middle-of-the-road quality", "neutral"),
81+
("Functions adequately", "neutral"),
82+
("Basic functionality works fine", "neutral"),
83+
("Got what I paid for", "neutral"),
84+
("Standard delivery time and service", "neutral"),
85+
86+
# Technical feedback
87+
("Getting error code 404 when trying to sync", "technical"),
88+
("App crashes after latest update", "technical"),
89+
("Can't connect to WiFi despite correct password", "technical"),
90+
("Battery drains even when device is off", "technical"),
91+
("Screen freezes during startup", "technical"),
92+
("Bluetooth pairing fails consistently", "technical"),
93+
("System shows unrecognized device error", "technical"),
94+
("Software keeps reverting to previous version", "technical"),
95+
("Memory full error after minimal usage", "technical"),
96+
("Device overheats during normal operation", "technical"),
97+
("USB port not recognizing connections", "technical"),
98+
("Network connectivity drops randomly", "technical"),
99+
("Authentication failed error on login", "technical"),
100+
("Sync process stuck at 99%", "technical"),
101+
("Database connection timeout error", "technical")
39102
]
40103

104+
# Number of times to replicate each example
105+
num_replications = 10 # This will create 10x more data
106+
41107
for text, label in feedback_data:
42-
texts.extend([text] * 10) # Replicate each example 10 times for demo
43-
labels.extend([label] * 10)
108+
# Add multiple copies of each example
109+
texts.extend([text] * num_replications)
110+
labels.extend([label] * num_replications)
111+
112+
logger.info(f"Total examples: {len(texts)}")
113+
logger.info(f"Examples per class: {sum(1 for l in labels if l == 'positive')}/{sum(1 for l in labels if l == 'negative')}/"
114+
f"{sum(1 for l in labels if l == 'neutral')}/{sum(1 for l in labels if l == 'technical')}")
44115

45116
# Create dataset and dataloader
46117
dataset = TextDataset(texts, labels)
47-
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
118+
batch_size = 8
119+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
120+
121+
# Calculate expected number of batches
122+
expected_batches = len(dataset) // batch_size + (1 if len(dataset) % batch_size != 0 else 0)
123+
logger.info(f"Expected number of batches: {expected_batches}")
48124

49125
# Process in batches
50126
start_time = time.time()
51127
for batch_idx, (batch_texts, batch_labels) in enumerate(dataloader):
52128
classifier.add_examples(batch_texts, batch_labels)
53-
if batch_idx % 10 == 0:
54-
logger.info(f"Processed batch {batch_idx}")
129+
if batch_idx % 5 == 0: # Log every 5 batches
130+
logger.info(f"Processed batch {batch_idx + 1}/{expected_batches}")
131+
132+
# Optional: log batch sizes to verify
133+
if batch_idx in [0, expected_batches // 2, expected_batches - 1]: # Print first, middle, and last batch
134+
logger.info(f"Batch {batch_idx + 1} size: {len(batch_texts)}")
55135

56-
logger.info(f"Processing time: {time.time() - start_time:.2f} seconds")
136+
processing_time = time.time() - start_time
137+
logger.info(f"Processing time: {processing_time:.2f} seconds")
138+
logger.info(f"Average time per batch: {processing_time/expected_batches:.2f} seconds")
57139

58140
return classifier
59141

60142
def demonstrate_continuous_learning():
61143
"""Example of continuous learning with performance monitoring"""
62144
logger.info("Demonstrating continuous learning...")
63145

64-
classifier = AdaptiveClassifier("bert-base-uncased")
146+
classifier = AdaptiveClassifier("distilbert/distilbert-base-cased")
65147

66148
# Initial training
67149
initial_texts = [
@@ -118,7 +200,7 @@ def evaluate_performance(test_texts: List[str], test_labels: List[str]) -> float
118200
def demonstrate_persistence():
119201
# 1. Create and train initial classifier
120202
print("Phase 1: Creating and training initial classifier")
121-
classifier = AdaptiveClassifier("bert-base-uncased")
203+
classifier = AdaptiveClassifier("distilbert/distilbert-base-cased")
122204

123205
# Add some initial examples
124206
initial_texts = [
@@ -170,38 +252,74 @@ def demonstrate_multi_language():
170252
logger.info("Demonstrating multi-language support...")
171253

172254
# Use a multilingual model
173-
classifier = AdaptiveClassifier("bert-base-multilingual-uncased")
255+
classifier = AdaptiveClassifier("distilbert/distilbert-base-multilingual-cased")
174256

175-
# Add examples in different languages
176257
texts = [
177-
# English
258+
# English - Positive
178259
"This is great",
260+
"I love this product",
261+
"Amazing experience",
262+
"Excellent service",
263+
"Best purchase ever",
264+
"Highly recommended",
265+
"Really impressive quality",
266+
"Fantastic results",
267+
268+
# English - Negative
179269
"This is terrible",
180-
# Spanish
270+
"Worst experience ever",
271+
"Don't waste your money",
272+
"Very disappointed",
273+
"Poor quality product",
274+
"Absolutely horrible",
275+
"Complete waste of time",
276+
"Not worth buying",
277+
278+
# Spanish - Positive
181279
"Esto es excelente",
280+
"Me encanta este producto",
281+
"Una experiencia maravillosa",
282+
"Servicio excepcional",
283+
"La mejor compra",
284+
"Muy recomendable",
285+
"Calidad impresionante",
286+
"Resultados fantásticos",
287+
288+
# Spanish - Negative
182289
"Esto es terrible",
183-
# French
184-
"C'est excellent",
185-
"C'est terrible"
290+
"La peor experiencia",
291+
"No malgastes tu dinero",
292+
"Muy decepcionado",
293+
"Producto de mala calidad",
294+
"Absolutamente horrible",
295+
"Pérdida total de tiempo",
296+
"No vale la pena comprarlo",
186297
]
187-
188-
labels = ["positive", "negative"] * 3
298+
299+
labels = ["positive"] * 8 + ["negative"] * 8 \
300+
+ ["positive"] * 8 + ["negative"] * 8
189301

190302
classifier.add_examples(texts, labels)
191303

192304
# Test in different languages
193305
test_texts = [
194-
"This is wonderful", # English
195-
"Esto es maravilloso", # Spanish
196-
"C'est merveilleux" # French
306+
# English
307+
"This is wonderful", # Positive
308+
"This is terrible", # Negative
309+
310+
# Spanish
311+
"Esto es maravilloso", # Positive
312+
"Esto es terrible", # Negative
197313
]
198-
314+
315+
# Print test results
316+
print("\nTesting predictions in multiple languages:")
199317
for text in test_texts:
200318
predictions = classifier.predict(text)
201-
logger.info(f"\nText: {text}")
202-
logger.info("Predictions:")
319+
print(f"\nText: {text}")
320+
print("Predictions:")
203321
for label, score in predictions:
204-
logger.info(f"{label}: {score:.4f}")
322+
print(f"{label}: {score:.4f}")
205323

206324
return classifier
207325

examples/basic_usage.py

+66-23
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,55 @@
1+
import torch
2+
import numpy as np
3+
import random
14
from adaptive_classifier import AdaptiveClassifier
25

36
def main():
7+
48
# Initialize classifier
5-
classifier = AdaptiveClassifier("bert-base-uncased")
9+
classifier = AdaptiveClassifier("distilbert/distilbert-base-cased")
610

7-
# Initial training data
11+
# Initial training data with at least 5 examples per class
812
texts = [
13+
# Positive examples
914
"The product works great!",
1015
"Amazing service, very satisfied",
1116
"This exceeded my expectations",
17+
"Best purchase I've made this year",
18+
"Really impressed with the quality",
19+
"Fantastic product, will buy again",
20+
"Highly recommend this to everyone",
21+
22+
# Negative examples
1223
"Terrible experience, don't buy",
1324
"Worst product ever",
25+
"Complete waste of money",
26+
"Poor quality and bad service",
27+
"Would not recommend to anyone",
28+
"Disappointed with the purchase",
29+
"Product broke after first use",
30+
31+
# Neutral examples
1432
"Product arrived on time",
15-
"Does what it says"
33+
"Does what it says",
34+
"Average product, nothing special",
35+
"Meets basic requirements",
36+
"Fair price for what you get",
37+
"Standard quality product",
38+
"Works as expected"
1639
]
1740

1841
labels = [
42+
# Positive labels
43+
"positive", "positive", "positive", "positive",
1944
"positive", "positive", "positive",
20-
"negative", "negative",
21-
"neutral", "neutral"
45+
46+
# Negative labels
47+
"negative", "negative", "negative", "negative",
48+
"negative", "negative", "negative",
49+
50+
# Neutral labels
51+
"neutral", "neutral", "neutral", "neutral",
52+
"neutral", "neutral", "neutral"
2253
]
2354

2455
# Add examples
@@ -27,18 +58,21 @@ def main():
2758

2859
# Test predictions
2960
test_texts = [
30-
"This is fantastic!",
31-
"I hate this product",
32-
"It's okay, nothing special"
61+
"This is a fantastic product!",
62+
"Disappointed with this bad product",
63+
"Average product, as expected"
3364
]
3465

3566
print("\nTesting predictions:")
36-
for text in test_texts:
37-
predictions = classifier.predict(text)
38-
print(f"\nText: {text}")
39-
print("Predictions:")
40-
for label, score in predictions:
41-
print(f"{label}: {score:.4f}")
67+
classifier.model.eval()
68+
69+
with torch.no_grad():
70+
for text in test_texts:
71+
predictions = classifier.predict(text)
72+
print(f"\nText: {text}")
73+
print("Predictions:")
74+
for label, score in predictions:
75+
print(f"{label}: {score:.4f}")
4276

4377
# Save the classifier
4478
print("\nSaving classifier...")
@@ -48,24 +82,33 @@ def main():
4882
print("\nLoading classifier...")
4983
loaded_classifier = AdaptiveClassifier.load("./demo_classifier")
5084

51-
# Add new class
85+
# Add new technical class with more examples
5286
print("\nAdding new technical class...")
5387
technical_texts = [
5488
"Error code 404 appeared",
55-
"System crashed after update"
89+
"System crashed after update",
90+
"Cannot connect to database",
91+
"Memory allocation failed",
92+
"Null pointer exception detected",
93+
"API endpoint not responding",
94+
"Stack overflow in main thread"
5695
]
57-
technical_labels = ["technical"] * 2
96+
technical_labels = ["technical"] * len(technical_texts)
5897

5998
loaded_classifier.add_examples(technical_texts, technical_labels)
6099

61100
# Test new predictions
62101
print("\nTesting technical classification:")
63-
technical_test = "Getting null pointer exception"
64-
predictions = loaded_classifier.predict(technical_test)
65-
print(f"\nText: {technical_test}")
66-
print("Predictions:")
67-
for label, score in predictions:
68-
print(f"{label}: {score:.4f}")
102+
technical_test = "API giving null pointer exception"
103+
104+
loaded_classifier.model.eval()
105+
106+
with torch.no_grad():
107+
predictions = loaded_classifier.predict(technical_test)
108+
print(f"\nText: {technical_test}")
109+
print("Predictions:")
110+
for label, score in predictions:
111+
print(f"{label}: {score:.4f}")
69112

70113
if __name__ == "__main__":
71114
main()

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ faiss-cpu>=1.7.4 # Use faiss-gpu for GPU support
55
numpy>=1.24.0
66
tqdm>=4.65.0
77
setuptools>=65.0.0
8-
wheel>=0.40.0
8+
wheel>=0.40.0
9+
scikit-learn

0 commit comments

Comments
 (0)