Skip to content

Commit efd9362

Browse files
committed
added dump, removed cnn_kwargs
1 parent e53a79c commit efd9362

File tree

5 files changed

+197
-17
lines changed

5 files changed

+197
-17
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,4 @@ tests_logs
179179
tests/logs
180180
runs/
181181
vector_db*
182+
/wandb

autointent/modules/scoring/_cnn/cnn.py

+48-15
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import numpy.typing as npt
99
import torch
10-
from torch import nn, Tensor
10+
from torch import nn
1111
from torch.utils.data import DataLoader, TensorDataset
1212

1313
from autointent import Context
@@ -16,31 +16,54 @@
1616
from autointent.modules.base import BaseScorer
1717
from autointent.modules.scoring._cnn.textcnn import TextCNN
1818

19-
2019
class CNNScorer(BaseScorer):
21-
"""Convolutional Neural Network (CNN) scorer for intent classification."""
20+
"""Convolutional Neural Network (CNN) scorer for intent classification.
21+
22+
Args:
23+
max_seq_length: Maximum length of input sequences.
24+
num_train_epochs: Number of training epochs.
25+
batch_size: Batch size for training.
26+
learning_rate: Learning rate for optimizer.
27+
seed: Random seed.
28+
report_to: Where to report training metrics.
29+
embed_dim: Dimension of word embeddings.
30+
kernel_sizes: Tuple of kernel sizes for convolutional layers.
31+
num_filters: Number of filters for each convolutional layer.
32+
dropout: Dropout rate.
33+
pretrained_embs: Pretrained embeddings tensor (optional).
34+
"""
2235

2336
name = "cnn"
2437
supports_multilabel = True
2538
supports_multiclass = True
2639

27-
def __init__(
40+
def __init__( # noqa: PLR0913
2841
self,
2942
max_seq_length: int = 50,
3043
num_train_epochs: int = 3,
3144
batch_size: int = 8,
3245
learning_rate: float = 5e-5,
3346
seed: int = 0,
34-
report_to: REPORTERS_NAMES | None = None, # type: ignore[no-any-return]
35-
**cnn_kwargs: dict[str, Any],
47+
report_to: REPORTERS_NAMES | None = None,
48+
embed_dim: int = 128,
49+
kernel_sizes: tuple[int, ...] = (3, 4, 5),
50+
num_filters: int = 100,
51+
dropout: float = 0.1,
52+
pretrained_embs: torch.Tensor | None = None,
3653
) -> None:
3754
self.max_seq_length = max_seq_length
3855
self.num_train_epochs = num_train_epochs
3956
self.batch_size = batch_size
4057
self.learning_rate = learning_rate
4158
self.seed = seed
4259
self.report_to = report_to
43-
self.cnn_config = cnn_kwargs
60+
61+
# CNN-specific parameters
62+
self.embed_dim = embed_dim
63+
self.kernel_sizes = kernel_sizes
64+
self.num_filters = num_filters
65+
self.dropout = dropout
66+
self.pretrained_embs = pretrained_embs
4467

4568
# Will be initialized during fit()
4669
self._model: TextCNN | None = None
@@ -53,22 +76,32 @@ def __init__(
5376
self._multilabel: bool = False
5477

5578
@classmethod
56-
def from_context(
79+
def from_context( # noqa: PLR0913
5780
cls,
5881
context: Context,
82+
max_seq_length: int = 50,
5983
num_train_epochs: int = 3,
6084
batch_size: int = 8,
6185
learning_rate: float = 5e-5,
6286
seed: int = 0,
63-
**cnn_kwargs: dict[str, Any],
87+
embed_dim: int = 128,
88+
kernel_sizes: tuple[int, ...] = (3, 4, 5),
89+
num_filters: int = 100,
90+
dropout: float = 0.1,
91+
pretrained_embs: torch.Tensor | None = None,
6492
) -> "CNNScorer":
6593
return cls(
94+
max_seq_length=max_seq_length,
6695
num_train_epochs=num_train_epochs,
6796
batch_size=batch_size,
6897
learning_rate=learning_rate,
6998
seed=seed,
7099
report_to=context.logging_config.report_to,
71-
**cnn_kwargs,
100+
embed_dim=embed_dim,
101+
kernel_sizes=kernel_sizes,
102+
num_filters=num_filters,
103+
dropout=dropout,
104+
pretrained_embs=pretrained_embs,
72105
)
73106

74107
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
@@ -94,12 +127,12 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
94127
self._model = TextCNN(
95128
vocab_size=len(self._vocab),
96129
n_classes=self._n_classes,
97-
embed_dim=self.cnn_config.get("embed_dim", 128),
98-
kernel_sizes=self.cnn_config.get("kernel_sizes", (3, 4, 5)),
99-
num_filters=self.cnn_config.get("num_filters", 100),
100-
dropout=self.cnn_config.get("dropout", 0.1),
130+
embed_dim=self.embed_dim,
131+
kernel_sizes=self.kernel_sizes,
132+
num_filters=self.num_filters,
133+
dropout=self.dropout,
101134
padding_idx=self._padding_idx,
102-
pretrained_embs=self.cnn_config.get("pretrained_embs", None),
135+
pretrained_embs=self.pretrained_embs,
103136
)
104137

105138
# Training

autointent/modules/scoring/_cnn/textcnn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""TextCNN model for text classification."""
22

3-
import torch
4-
import torch.nn.functional as F
53
from torch import nn
4+
import torch
5+
import torch.nn.functional as F # noqa: N812
66

77

88
class TextCNN(nn.Module):

test_main.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from autointent.modules.scoring._cnn.cnn import CNNScorer
2+
3+
# Sample data
4+
utterances = [
5+
"I love programming",
6+
"I hate bugs",
7+
"Python is awesome",
8+
"Debugging is frustrating",
9+
"Machine learning is fun",
10+
"I dislike errors",
11+
]
12+
print(utterances)
13+
labels = [1, 0, 1, 0, 1, 0] # 1 = positive, 0 = negative
14+
15+
# Initialize the scorer
16+
scorer = CNNScorer()
17+
# Train the model
18+
print('before fit')
19+
scorer.fit(utterances, labels)
20+
21+
# Test set
22+
test_utterances = [
23+
"I enjoy coding",
24+
"I find bugs annoying",
25+
"AI is fascinating",
26+
"Errors are frustrating",
27+
]
28+
29+
# Predict probabilities
30+
probabilities = scorer.predict(test_utterances)
31+
print("Predicted Probabilities:")
32+
print(probabilities)
33+
34+
# Convert probabilities to predicted labels
35+
predicted_labels = (probabilities > 0.5).astype(int) # For binary classification
36+
print("Predicted Labels:")
37+
print(predicted_labels)
38+
39+
# Expected labels for the test set
40+
expected_labels = [1, 0, 1, 0]
41+
42+
# Compare predicted and expected labels
43+
for i, (pred, exp) in enumerate(zip(predicted_labels, expected_labels)):
44+
print(f"Test Utterance {i+1}: {test_utterances[i]}")
45+
print(f"Predicted: {pred}, Expected: {exp}")

tests/modules/scoring/test_cnn.py

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# import numpy as np
2+
# import pytest
3+
4+
# from autointent.context.data_handler import DataHandler
5+
# from autointent.modules.scoring._cnn import CNNScorer
6+
7+
8+
# def test_cnn_prediction(dataset):
9+
# """Test that the CNN model can fit and make predictions."""
10+
# data_handler = DataHandler(dataset)
11+
12+
# scorer = CNNScorer(
13+
# max_seq_length=50,
14+
# num_train_epochs=1,
15+
# batch_size=8,
16+
# learning_rate=5e-5,
17+
# embed_dim=128,
18+
# kernel_sizes=(3, 4, 5),
19+
# num_filters=100,
20+
# dropout=0.1
21+
# )
22+
# scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
23+
24+
# test_data = [
25+
# "why is there a hold on my account",
26+
# "i am not sure why my account is blocked",
27+
# "why is there a hold on my checking account",
28+
# "i think my account is blocked",
29+
# "can you tell me why is my account frozen",
30+
# ]
31+
32+
# predictions = scorer.predict(test_data)
33+
34+
# assert predictions.shape[0] == len(test_data)
35+
# assert predictions.shape[1] == len(set(data_handler.train_labels(0)))
36+
37+
# # Check that predictions lie in the range [0, 1]
38+
# assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0
39+
40+
# # For multiclass classification, each row of predictions should sum to ~1.0
41+
# if not scorer._multilabel:
42+
# for pred_row in predictions:
43+
# np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)
44+
45+
# # Check predict_with_metadata behavior if the method exists
46+
# if hasattr(scorer, "predict_with_metadata"):
47+
# predictions, metadata = scorer.predict_with_metadata(test_data)
48+
# assert len(predictions) == len(test_data)
49+
# assert metadata is None
50+
51+
52+
# def test_cnn_cache_clearing(dataset):
53+
# """Test that the CNN model properly handles cache clearing."""
54+
# data_handler = DataHandler(dataset)
55+
56+
# scorer = CNNScorer(
57+
# max_seq_length=50,
58+
# num_train_epochs=1,
59+
# batch_size=8,
60+
# learning_rate=5e-5
61+
# )
62+
# scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
63+
64+
# test_data = ["test text"]
65+
66+
# # First prediction
67+
# scorer.predict(test_data)
68+
69+
# # Clear the cache
70+
# scorer.clear_cache()
71+
72+
# # Check that the model has been cleared
73+
# assert not hasattr(scorer, "_model") or scorer._model is None
74+
# assert not hasattr(scorer, "_vocab") or scorer._vocab is None
75+
76+
# # After clearing the cache, predictions should raise an error
77+
# with pytest.raises(ValueError, match="Model not trained. Call fit() first."):
78+
# scorer.predict(test_data)
79+
80+
81+
# def test_cnn_multilabel(dataset_multilabel):
82+
# """Test CNN scorer with multilabel data."""
83+
# data_handler = DataHandler(dataset_multilabel)
84+
85+
# scorer = CNNScorer(
86+
# max_seq_length=50,
87+
# num_train_epochs=1,
88+
# batch_size=8,
89+
# learning_rate=5e-5
90+
# )
91+
# scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
92+
93+
# test_data = ["sample text for testing", "another test example"]
94+
# predictions = scorer.predict(test_data)
95+
96+
# # For multilabel, check that the output probabilities are independent
97+
# assert predictions.shape[0] == len(test_data)
98+
# assert predictions.shape[1] == len(data_handler.train_labels(0)[0])
99+
100+
# # Check that there are predictions strictly between 0 and 1
101+
# assert np.any((predictions > 0) & (predictions < 1))

0 commit comments

Comments
 (0)