7
7
import numpy as np
8
8
import numpy .typing as npt
9
9
import torch
10
- from torch import nn , Tensor
10
+ from torch import nn
11
11
from torch .utils .data import DataLoader , TensorDataset
12
12
13
13
from autointent import Context
16
16
from autointent .modules .base import BaseScorer
17
17
from autointent .modules .scoring ._cnn .textcnn import TextCNN
18
18
19
-
20
19
class CNNScorer (BaseScorer ):
21
- """Convolutional Neural Network (CNN) scorer for intent classification."""
20
+ """Convolutional Neural Network (CNN) scorer for intent classification.
21
+
22
+ Args:
23
+ max_seq_length: Maximum length of input sequences.
24
+ num_train_epochs: Number of training epochs.
25
+ batch_size: Batch size for training.
26
+ learning_rate: Learning rate for optimizer.
27
+ seed: Random seed.
28
+ report_to: Where to report training metrics.
29
+ embed_dim: Dimension of word embeddings.
30
+ kernel_sizes: Tuple of kernel sizes for convolutional layers.
31
+ num_filters: Number of filters for each convolutional layer.
32
+ dropout: Dropout rate.
33
+ pretrained_embs: Pretrained embeddings tensor (optional).
34
+ """
22
35
23
36
name = "cnn"
24
37
supports_multilabel = True
25
38
supports_multiclass = True
26
39
27
- def __init__ (
40
+ def __init__ ( # noqa: PLR0913
28
41
self ,
29
42
max_seq_length : int = 50 ,
30
43
num_train_epochs : int = 3 ,
31
44
batch_size : int = 8 ,
32
45
learning_rate : float = 5e-5 ,
33
46
seed : int = 0 ,
34
- report_to : REPORTERS_NAMES | None = None , # type: ignore[no-any-return]
35
- ** cnn_kwargs : dict [str , Any ],
47
+ report_to : REPORTERS_NAMES | None = None ,
48
+ embed_dim : int = 128 ,
49
+ kernel_sizes : tuple [int , ...] = (3 , 4 , 5 ),
50
+ num_filters : int = 100 ,
51
+ dropout : float = 0.1 ,
52
+ pretrained_embs : torch .Tensor | None = None ,
36
53
) -> None :
37
54
self .max_seq_length = max_seq_length
38
55
self .num_train_epochs = num_train_epochs
39
56
self .batch_size = batch_size
40
57
self .learning_rate = learning_rate
41
58
self .seed = seed
42
59
self .report_to = report_to
43
- self .cnn_config = cnn_kwargs
60
+
61
+ # CNN-specific parameters
62
+ self .embed_dim = embed_dim
63
+ self .kernel_sizes = kernel_sizes
64
+ self .num_filters = num_filters
65
+ self .dropout = dropout
66
+ self .pretrained_embs = pretrained_embs
44
67
45
68
# Will be initialized during fit()
46
69
self ._model : TextCNN | None = None
@@ -53,22 +76,32 @@ def __init__(
53
76
self ._multilabel : bool = False
54
77
55
78
@classmethod
56
- def from_context (
79
+ def from_context ( # noqa: PLR0913
57
80
cls ,
58
81
context : Context ,
82
+ max_seq_length : int = 50 ,
59
83
num_train_epochs : int = 3 ,
60
84
batch_size : int = 8 ,
61
85
learning_rate : float = 5e-5 ,
62
86
seed : int = 0 ,
63
- ** cnn_kwargs : dict [str , Any ],
87
+ embed_dim : int = 128 ,
88
+ kernel_sizes : tuple [int , ...] = (3 , 4 , 5 ),
89
+ num_filters : int = 100 ,
90
+ dropout : float = 0.1 ,
91
+ pretrained_embs : torch .Tensor | None = None ,
64
92
) -> "CNNScorer" :
65
93
return cls (
94
+ max_seq_length = max_seq_length ,
66
95
num_train_epochs = num_train_epochs ,
67
96
batch_size = batch_size ,
68
97
learning_rate = learning_rate ,
69
98
seed = seed ,
70
99
report_to = context .logging_config .report_to ,
71
- ** cnn_kwargs ,
100
+ embed_dim = embed_dim ,
101
+ kernel_sizes = kernel_sizes ,
102
+ num_filters = num_filters ,
103
+ dropout = dropout ,
104
+ pretrained_embs = pretrained_embs ,
72
105
)
73
106
74
107
def fit (self , utterances : list [str ], labels : ListOfLabels ) -> None :
@@ -94,12 +127,12 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
94
127
self ._model = TextCNN (
95
128
vocab_size = len (self ._vocab ),
96
129
n_classes = self ._n_classes ,
97
- embed_dim = self .cnn_config . get ( " embed_dim" , 128 ) ,
98
- kernel_sizes = self .cnn_config . get ( " kernel_sizes" , ( 3 , 4 , 5 )) ,
99
- num_filters = self .cnn_config . get ( " num_filters" , 100 ) ,
100
- dropout = self .cnn_config . get ( " dropout" , 0.1 ) ,
130
+ embed_dim = self .embed_dim ,
131
+ kernel_sizes = self .kernel_sizes ,
132
+ num_filters = self .num_filters ,
133
+ dropout = self .dropout ,
101
134
padding_idx = self ._padding_idx ,
102
- pretrained_embs = self .cnn_config . get ( " pretrained_embs" , None ) ,
135
+ pretrained_embs = self .pretrained_embs ,
103
136
)
104
137
105
138
# Training
0 commit comments