33from typing import Any , Dict , Iterable , Optional , Sequence , Set , Union
44
55import torch
6+ import torch .nn as nn
67from spacy .tokens import Doc
7- from typing_extensions import NotRequired , TypedDict
8+ from typing_extensions import Literal , NotRequired , TypedDict
89
910from edsnlp .core .pipeline import PipelineProtocol
1011from edsnlp .core .torch_component import BatchInput , TorchComponent
@@ -36,6 +37,8 @@ class TrainableDocClassifier(
3637 TorchComponent [DocClassifierBatchOutput , DocClassifierBatchInput ],
3738 BaseComponent ,
3839):
40+ """A trainable document classifier that uses embeddings to classify documents."""
41+
3942 def __init__ (
4043 self ,
4144 nlp : Optional [PipelineProtocol ] = None ,
@@ -49,12 +52,21 @@ def __init__(
4952 loss_fn = None ,
5053 labels : Optional [Sequence [str ]] = None ,
5154 class_weights : Optional [Union [Dict [str , float ], str ]] = None ,
55+ hidden_size : Optional [int ] = None ,
56+ activation_mode : Literal ["relu" , "gelu" , "silu" ] = "relu" ,
57+ dropout_rate : Optional [float ] = 0.0 ,
58+ layer_norm : Optional [bool ] = False ,
5259 ):
60+ self .num_classes = num_classes
5361 self .label_attr : Attributes = label_attr
5462 self .label2id = label2id or {}
5563 self .id2label = id2label or {}
5664 self .labels = labels
5765 self .class_weights = class_weights
66+ self .hidden_size = hidden_size
67+ self .activation_mode = activation_mode
68+ self .dropout_rate = dropout_rate
69+ self .layer_norm = layer_norm
5870
5971 super ().__init__ (nlp , name )
6072 self .embedding = embedding
@@ -66,9 +78,23 @@ def __init__(
6678 raise ValueError (
6779 "The embedding component must have an 'output_size' attribute."
6880 )
69- embedding_size = self .embedding .output_size
70- if num_classes :
71- self .classifier = torch .nn .Linear (embedding_size , num_classes )
81+ self .embedding_size = self .embedding .output_size
82+ if self .num_classes :
83+ self .build_classifier ()
84+
def build_classifier(self):
    """Build the classification head and attach its layers to ``self``.

    When ``self.hidden_size`` is falsy (``None`` or ``0``), the head is a
    single linear projection from ``self.embedding_size`` to
    ``self.num_classes``, stored in ``self.classifier``.

    Otherwise the head is made of:

    - ``self.hidden_layer``: linear ``embedding_size -> hidden_size``
    - ``self.activation``: the module selected by ``self.activation_mode``
    - ``self.norm``: a ``LayerNorm`` over ``hidden_size``, only created when
      ``self.layer_norm`` is truthy
    - ``self.dropout``: dropout with probability ``self.dropout_rate``
    - ``self.classifier``: linear ``hidden_size -> num_classes``

    Raises
    ------
    ValueError
        If a hidden layer is requested and ``self.activation_mode`` is not
        one of ``"relu"``, ``"gelu"`` or ``"silu"``.
    """
    if not self.hidden_size:
        self.classifier = torch.nn.Linear(self.embedding_size, self.num_classes)
        return

    # Map names to classes (not instances) so that only the selected
    # activation module is actually instantiated.
    activations = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}
    if self.activation_mode not in activations:
        raise ValueError(
            f"Unknown activation_mode {self.activation_mode!r}, "
            f"expected one of {sorted(activations)}."
        )

    # `dropout_rate` is declared Optional: nn.Dropout(None) would fail on its
    # range check, so treat None as "no dropout".
    dropout_rate = 0.0 if self.dropout_rate is None else self.dropout_rate

    # Layers are created in the same order as before so that parameter
    # initialization under a fixed random seed is unchanged.
    self.hidden_layer = torch.nn.Linear(self.embedding_size, self.hidden_size)
    self.activation = activations[self.activation_mode]()
    if self.layer_norm:
        self.norm = nn.LayerNorm(self.hidden_size)
    self.dropout = nn.Dropout(dropout_rate)
    self.classifier = torch.nn.Linear(self.hidden_size, self.num_classes)
7298
7399 def _compute_class_weights (self , freq_dict : Dict [str , int ]) -> torch .Tensor :
74100 """
@@ -112,10 +138,9 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
112138 for i , label in enumerate (labels ):
113139 self .label2id [label ] = i
114140 self .id2label [i ] = label
115- print ("num classes:" , len (self .label2id ))
116- self .classifier = torch .nn .Linear (
117- self .embedding .output_size , len (self .label2id )
118- )
141+ self .num_classes = len (self .label2id )
142+ print ("num classes:" , self .num_classes )
143+ self .build_classifier ()
119144
120145 weight_tensor = None
121146 if self .class_weights is not None :
@@ -138,6 +163,7 @@ def preprocess(self, doc: Doc) -> Dict[str, Any]:
138163 return {"embedding" : self .embedding .preprocess (doc )}
139164
140165 def preprocess_supervised (self , doc : Doc ) -> Dict [str , Any ]:
166+ """Preprocess document with target labels for training."""
141167 preps = self .preprocess (doc )
142168 label = getattr (doc ._ , self .label_attr , None )
143169 if label is None :
@@ -166,9 +192,14 @@ def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
166192 if targets provided.
167193 """
168194 pooled = self .embedding (batch ["embedding" ])
169- embeddings = pooled ["embeddings" ]
170-
171- logits = self .classifier (embeddings )
195+ x = pooled ["embeddings" ]
196+ if self .hidden_size :
197+ x = self .hidden_layer (x )
198+ x = self .activation (x )
199+ if self .layer_norm :
200+ x = self .norm (x )
201+ x = self .dropout (x )
202+ logits = self .classifier (x )
172203
173204 output : DocClassifierBatchOutput = {}
174205 if "targets" in batch :
@@ -181,6 +212,7 @@ def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
181212 return output
182213
183214 def postprocess (self , docs , results , input ):
215+ """Postprocess predictions by assigning labels to documents."""
184216 labels = results ["labels" ]
185217 if isinstance (labels , torch .Tensor ):
186218 labels = labels .tolist ()
0 commit comments