diff --git a/changelog.md b/changelog.md index 92b294b1ae..8650169a6e 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,12 @@ # Changelog +## Unreleased + +### Added + +- New `attention` pooling mode in `eds.span_pooler` +- New `word_pooling_mode=False` in `eds.transformer` to allow returning the worpiece embeddings directly, instead of the mean-pooled word embeddings. At the moment, this only works with `eds.span_pooler` which can pool over wordpieces or words seamlessly. + ## v0.18.0 (2025-09-02) 📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 support in the next major release (v0.19.0), in October 2025. Please upgrade to Python 3.10 or later. @@ -13,6 +20,7 @@ - New `eds.explode` pipe that splits one document into multiple documents, one per span yielded by its `span_getter` parameter, each new document containing exactly that single span. - New `Training a span classifier` tutorial, and reorganized deep-learning docs - `ScheduledOptimizer` now warns when a parameter selector does not match any parameter. +- New `attention` pooling mode in `eds.span_pooler` ### Fixed diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index 9f563e8fce..9d20e11042 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -4,6 +4,7 @@ We provide step-by-step guides to get you started. We cover the following use-ca ### Base tutorials + === card {: href=/tutorials/spacy101 } @@ -85,6 +86,8 @@ We provide step-by-step guides to get you started. We cover the following use-ca --- Quickly visualize the results of your pipeline as annotations or tables. + + ### Deep learning tutorials We also provide tutorials on how to train deep-learning models with EDS-NLP. These tutorials cover the training API, hyperparameter tuning, and more. @@ -123,8 +126,5 @@ We also provide tutorials on how to train deep-learning models with EDS-NLP. The --- Learn how to tune hyperparameters of a model with `edsnlp.tune`. - - - diff --git a/edsnlp/core/torch_component.py b/edsnlp/core/torch_component.py index e78b4fc6b6..0d50538a36 100644 --- a/edsnlp/core/torch_component.py +++ b/edsnlp/core/torch_component.py @@ -339,7 +339,14 @@ def compute_training_metrics( This is useful to compute averages when doing multi-gpu training or mini-batch accumulation since full denominators are not known during the forward pass. """ - return batch_output + return ( + { + **batch_output, + "loss": batch_output["loss"] / count, + } + if "loss" in batch_output + else batch_output + ) def module_forward(self, *args, **kwargs): # pragma: no cover """ @@ -348,6 +355,31 @@ def module_forward(self, *args, **kwargs): # pragma: no cover """ return torch.nn.Module.__call__(self, *args, **kwargs) + def preprocess_batch(self, docs: Sequence[Doc], supervision=False, **kwargs): + """ + Convenience method to preprocess a batch of documents. + Features corresponding to the same path are grouped together in a list, + under the same key. 
+ + Parameters + ---------- + docs: Sequence[Doc] + Batch of documents + supervision: bool + Whether to extract supervision features or not + + Returns + ------- + Dict[str, Sequence[Any]] + The batch of features + """ + batch = [ + (self.preprocess_supervised(d) if supervision else self.preprocess(d)) + for d in docs + ] + batch = decompress_dict(list(batch_compress_dict(batch))) + return batch + def prepare_batch( self, docs: Sequence[Doc], @@ -372,11 +404,7 @@ def prepare_batch( ------- Dict[str, Sequence[Any]] """ - batch = [ - (self.preprocess_supervised(doc) if supervision else self.preprocess(doc)) - for doc in docs - ] - batch = decompress_dict(list(batch_compress_dict(batch))) + batch = self.preprocess_batch(docs, supervision=supervision) batch = self.collate(batch) batch = self.batch_to_device(batch, device=device) return batch diff --git a/edsnlp/metrics/doc_classif.py b/edsnlp/metrics/doc_classif.py new file mode 100644 index 0000000000..c9048eb3ca --- /dev/null +++ b/edsnlp/metrics/doc_classif.py @@ -0,0 +1,190 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +from spacy.tokens import Doc +from spacy.training import Example + +from edsnlp import registry +from edsnlp.metrics import make_examples + + +def doc_classification_metric( + examples: Union[Tuple[Iterable[Doc], Iterable[Doc]], Iterable[Example]], + label_attr: List[str], + micro_key: str = "micro", + macro_key: str = "macro", + filter_expr: Optional[str] = None, +) -> Dict[str, Dict[str, Any]]: + """ + Scores multi-head document-level classification (accuracy, precision, recall, F1) + for each head. + + Parameters + ---------- + examples: Examples + The examples to score, either a tuple of (golds, preds) or a list of + spacy.training.Example objects + label_attr: List[str] + The list of Doc._ attributes containing the labels for each head + micro_key: str + The key to use to store the micro-averaged results + macro_key: str + The key to use to store the macro-averaged results + filter_expr: str + The filter expression to use to filter the documents + + Returns + ------- + Dict[str, Dict[str, Any]] + Dictionary mapping head names to their respective metrics + """ + examples = make_examples(examples) + if filter_expr is not None: + filter_fn = eval(f"lambda doc: {filter_expr}") + examples = [eg for eg in examples if filter_fn(eg.reference)] + + all_head_results = {} + + for head_name in label_attr: + pred_labels = [] + gold_labels = [] + + for eg in examples: + pred = getattr(eg.predicted._, head_name, None) + gold = getattr(eg.reference._, head_name, None) + pred_labels.append(pred) + gold_labels.append(gold) + + labels = set(gold_labels) | set(pred_labels) + labels = {label for label in labels if label is not None} + head_results = {} + + for label in labels: + tp = sum( + 1 for p, g in zip(pred_labels, gold_labels) if p == label and g == label + ) + fp = sum( + 1 for p, g in zip(pred_labels, gold_labels) if p == label and g != label + ) + fn = sum( + 1 for p, g in zip(pred_labels, gold_labels) if g == label and p != label + ) + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = ( + (2 * precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) + + head_results[label] = { + "f": f1, + "p": precision, + "r": recall, + "tp": tp, + "fp": fp, + "fn": fn, + "support": tp + fn, + "positives": tp + fp, + } + + total_tp = sum(1 for p, g in zip(pred_labels, gold_labels) if p == g) + total_fp = sum(1 for p, g in 
zip(pred_labels, gold_labels) if p != g) + total_fn = total_fp + + micro_precision = ( + total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0 + ) + micro_recall = ( + total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0 + ) + micro_f1 = ( + (2 * micro_precision * micro_recall) / (micro_precision + micro_recall) + if (micro_precision + micro_recall) > 0 + else 0.0 + ) + accuracy = total_tp / len(pred_labels) if len(pred_labels) > 0 else 0.0 + + head_results[micro_key] = { + "accuracy": accuracy, + "f": micro_f1, + "p": micro_precision, + "r": micro_recall, + "tp": total_tp, + "fp": total_fp, + "fn": total_fn, + "support": len(gold_labels), + "positives": len(pred_labels), + } + + per_class_precisions = [head_results[label]["p"] for label in labels] + per_class_recalls = [head_results[label]["r"] for label in labels] + per_class_f1s = [head_results[label]["f"] for label in labels] + + macro_precision = ( + sum(per_class_precisions) / len(per_class_precisions) + if per_class_precisions + else 0.0 + ) + macro_recall = ( + sum(per_class_recalls) / len(per_class_recalls) + if per_class_recalls + else 0.0 + ) + macro_f1 = sum(per_class_f1s) / len(per_class_f1s) if per_class_f1s else 0.0 + + head_results[macro_key] = { + "f": macro_f1, + "p": macro_precision, + "r": macro_recall, + "support": len(labels), + "classes": len(labels), + } + + all_head_results[head_name] = head_results + + return all_head_results + + +@registry.metrics.register("eds.doc_classification") +class DocClassificationMetric: + def __init__( + self, + label_attr: List[str], + micro_key: str = "micro", + macro_key: str = "macro", + filter_expr: Optional[str] = None, + ): + """ + Multi-head document classification metric. + + Parameters + ---------- + label_attr: List[str] + List of Doc._ attributes containing the labels for each head + micro_key: str + The key to use to store the micro-averaged results + macro_key: str + The key to use to store the macro-averaged results + filter_expr: str + The filter expression to use to filter the documents + """ + self.label_attr = label_attr + self.micro_key = micro_key + self.macro_key = macro_key + self.filter_expr = filter_expr + + def __call__(self, *examples): + return doc_classification_metric( + examples, + label_attr=self.label_attr, + micro_key=self.micro_key, + macro_key=self.macro_key, + filter_expr=self.filter_expr, + ) + + +__all__ = [ + "doc_classification_metric", + "DocClassificationMetric", +] diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index aea3f0f088..ee42396b39 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -82,5 +82,7 @@ from .trainable.embeddings.span_pooler.factory import create_component as span_pooler from .trainable.embeddings.transformer.factory import create_component as transformer from .trainable.embeddings.text_cnn.factory import create_component as text_cnn + from .trainable.embeddings.doc_pooler.factory import create_component as doc_pooler + from .trainable.doc_classifier.factory import create_component as doc_classifier from .misc.split import Split as split from .misc.explode import Explode as explode diff --git a/edsnlp/pipes/trainable/doc_classifier/__init__.py b/edsnlp/pipes/trainable/doc_classifier/__init__.py new file mode 100644 index 0000000000..549d2fc779 --- /dev/null +++ b/edsnlp/pipes/trainable/doc_classifier/__init__.py @@ -0,0 +1 @@ +from .factory import create_component diff --git a/edsnlp/pipes/trainable/doc_classifier/doc_classifier.py 
b/edsnlp/pipes/trainable/doc_classifier/doc_classifier.py new file mode 100644 index 0000000000..c506a02ae9 --- /dev/null +++ b/edsnlp/pipes/trainable/doc_classifier/doc_classifier.py @@ -0,0 +1,679 @@ +import os +import pickle +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union + +import pandas as pd +import torch +import torch.nn as nn +from spacy.tokens import Doc +from typing_extensions import Literal, NotRequired, TypedDict + +import edsnlp +from edsnlp.core.pipeline import PipelineProtocol +from edsnlp.core.torch_component import BatchInput, TorchComponent +from edsnlp.pipes.base import BaseComponent +from edsnlp.pipes.trainable.embeddings.typing import ( + WordContextualizerComponent, + WordEmbeddingComponent, +) + +DocClassifierBatchInput = TypedDict( + "DocClassifierBatchInput", + { + "embedding": BatchInput, + "targets": NotRequired[Dict[str, torch.Tensor]], + }, +) + +DocClassifierBatchOutput = TypedDict( + "DocClassifierBatchOutput", + { + "loss": Optional[torch.Tensor], + "labels": Optional[Dict[str, torch.Tensor]], + }, +) + + +@edsnlp.registry.misc.register("focal_loss") +class FocalLoss(nn.Module): + """ + Focal Loss implementation for multi-class classification. + + Parameters + ---------- + alpha : torch.Tensor or float, optional + Class weights. If None, no weighting is applied + gamma : float, default=2.0 + Focusing parameter. Higher values give more weight to hard examples + reduction : str, default='mean' + Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum' + """ + + def __init__( + self, + alpha: Optional[Union[torch.Tensor, float]] = None, + gamma: float = 2.0, + reduction: str = "mean", + ): + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Forward pass + """ + ce_loss = torch.nn.functional.cross_entropy( + inputs, targets, weight=self.alpha, reduction="none" + ) + + pt = torch.exp(-ce_loss) + + focal_loss = (1 - pt) ** self.gamma * ce_loss + + if self.reduction == "mean": + return focal_loss.mean() + elif self.reduction == "sum": + return focal_loss.sum() + else: + return focal_loss + + +class TrainableDocClassifier( + TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput], + BaseComponent, +): + """ + The `eds.doc_classifier` component is a trainable document-level classifier. + In this context, document classification consists in predicting one or more + categorical labels at the **document level** (e.g. diagnosis code, discharge + status, or any metadata derived from the whole document). + + Unlike span classification, where predictions are attached to spans, the + document classifier attaches predictions to the `Doc` object itself. + + Architecture + ------------ + The model performs multi-head document classification by: + + 1. Calling a word/document embedding component `eds.doc_pooler` + to compute a pooled embedding for the document. + 2. Feeding the pooled embedding into one or more classification heads. + Each head is defined by a linear layer (optionally preceded by a + head-specific hidden layer with activation, dropout, and layer norm). + 3. Computing independent logits for each head. + 4. Training with a per-head loss (CrossEntropy or Focal), optionally using + class weights to handle imbalance. + 5. Aggregating head losses into a single training loss (simple average). + 6. 
During inference, assigning the predicted label for each head to + `doc._.labels[head_name]`. + + Each classification head is independent, so different tasks (e.g. + predicting ICD-10 category vs. mortality flag) can be trained jointly + on the same pooled embeddings. + + Examples + -------- + To create a document classifier component: + + ```python + import edsnlp, edsnlp.pipes as eds + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.doc_classifier( + label_attr=["icd10", "mortality"], + labels={ + "icd10": "data/path_to_label_list_icd10.pkl", + "mortality": "data/path_to_label_list_mortality.pkl", + }, + num_classes={ + "icd10": 1000, + "mortality": 2, + }, + class_weights={ + "icd10": "data/path_to_class_weights_icd10.pkl", + "mortality": "data/path_to_class_weights_mortality.pkl", + }, + embedding=eds.doc_pooler( + pooling_mode="attention", + embedding=eds.transformer( + model="almanach/camembertav2-base", + window=256, + stride=128, + ), + ), + hidden_size=1024, + activation_mode="relu", + dropout_rate={ + "icd10": 0.05, + "mortality": 0.2, + }, + layer_norm=True, + loss="ce", + ), + name="doc_classifier", + ) + ``` + + After training, predictions are stored in the `Doc` object: + + ```python + doc = nlp("Patient was admitted with pneumonia and discharged alive.") + print(doc._.icd10, doc._.mortality) + # J18 alive + ``` + + Parameters + ---------- + nlp : Optional[PipelineProtocol] + The spaCy/edsnlp pipeline the component belongs to. + name : str, default="doc_classifier" + Component name. + embedding : WordEmbeddingComponent or WordContextualizerComponent + Embedding component (e.g. transformer + pooling). + Must expose an `output_size` attribute. + label_attr : List[str] + List of head names. Each head corresponds to a document-level attribute + (e.g. `["icd10", "mortality"]`). + num_classes : dict of str -> int, optional + Number of classes for each head. If not provided, inferred from labels. + label2id : dict of str -> dict[str,int], optional + Per-head mapping from label string to integer ID. + id2label : dict of str -> dict[int,str], optional + Reverse mapping (ID -> label string). + loss : {"ce", "focal"} or dict[str, {"ce","focal"}], default="ce" + Loss type, either shared or per-head. + labels : dict of str -> str (path), optional + Paths to pickle files containing label sets for each head. + class_weights : dict of str -> str (path), optional + Paths to pickle files containing class frequency dicts + (converted into class weights). + hidden_size : int or dict[str,int], optional + Hidden layer size (before classifier), shared or per-head. + If None, no hidden layer is used. + activation_mode : {"relu","gelu","silu"} or dict[str,str], default="relu" + Activation function for hidden layers, shared or per-head. + dropout_rate : float or dict[str,float], default=0.0 + Dropout rate after activation, shared or per-head. + layer_norm : bool or dict[str,bool], default=False + Whether to apply layer normalization in hidden layers, shared or per-head. 
+ """ + + def __init__( + self, + nlp: Optional[PipelineProtocol] = None, + name: str = "doc_classifier", + *, + embedding: Union[WordEmbeddingComponent, WordContextualizerComponent], + label_attr: List[str], + num_classes: Optional[Dict[str, int]] = None, + label2id: Optional[Dict[str, Dict[str, int]]] = None, + id2label: Optional[Dict[str, Dict[int, str]]] = None, + loss: Union[Literal["ce", "focal"], Dict[str, Literal["ce", "focal"]]] = "ce", + labels: Optional[Dict[str, str]] = None, + class_weights: Optional[Dict[str, str]] = None, + hidden_size: Optional[Union[int, Dict[str, int]]] = None, + activation_mode: Union[ + Literal["relu", "gelu", "silu"], Dict[str, Literal["relu", "gelu", "silu"]] + ] = "relu", + dropout_rate: Optional[Union[float, Dict[str, float]]] = 0.0, + layer_norm: Optional[Union[bool, Dict[str, bool]]] = False, + ): + if not isinstance(label_attr, list) or len(label_attr) == 0: + raise ValueError("label_attr must be a non-empty list of strings") + + self.label_attr: List[str] = label_attr + self.head_names = label_attr + + self.num_classes = num_classes or {} + self.label2id = label2id or {head: {} for head in self.head_names} + self.id2label = id2label or {head: {} for head in self.head_names} + + self.labels_from_pickle = {} + if labels: + for head_name, labels_path in labels.items(): + if head_name in self.head_names: + head_labels = pd.read_pickle(labels_path) + self.labels_from_pickle[head_name] = head_labels + self.num_classes[head_name] = len(head_labels) + + self.class_weights = {} + if class_weights: + for head_name, weights_path in class_weights.items(): + if head_name in self.head_names: + self.class_weights[head_name] = pd.read_pickle(weights_path) + + if isinstance(loss, str): + self.loss_config = {head: loss for head in self.head_names} + else: + self.loss_config = loss + + if isinstance(hidden_size, (int, type(None))): + self.hidden_size_config = {head: hidden_size for head in self.head_names} + else: + self.hidden_size_config = hidden_size + + if isinstance(activation_mode, str): + self.activation_mode_config = { + head: activation_mode for head in self.head_names + } + else: + self.activation_mode_config = activation_mode + + if isinstance(dropout_rate, (float, type(None))): + self.dropout_rate_config = {head: dropout_rate for head in self.head_names} + else: + self.dropout_rate_config = dropout_rate + + if isinstance(layer_norm, bool): + self.layer_norm_config = {head: layer_norm for head in self.head_names} + else: + self.layer_norm_config = layer_norm + + super().__init__(nlp, name) + self.embedding = embedding + + if not hasattr(self.embedding, "output_size"): + raise ValueError( + "The embedding component must have an 'output_size' attribute." + ) + self.embedding_size = self.embedding.output_size + + if any(head in self.num_classes for head in self.head_names): + self.build_classifiers() + + def build_classifiers(self): + """ + Build classification heads for each task. + + For every head in `self.head_names`, creates: + - An optional hidden layer (`Linear + activation + dropout [+ layer norm]`). + - A final linear classifier projecting to `num_classes[head_name]`. + + All heads are stored in `nn.ModuleDict`s for modularity. 
+ """ + self.classifiers = nn.ModuleDict() + self.hidden_layers = nn.ModuleDict() + self.activations = nn.ModuleDict() + self.norms = nn.ModuleDict() + self.dropouts = nn.ModuleDict() + + for head_name in self.head_names: + if head_name in self.num_classes: + hidden_size = self.hidden_size_config.get(head_name) + + if hidden_size: + self.hidden_layers[head_name] = torch.nn.Linear( + self.embedding_size, hidden_size + ) + + activation_mode = self.activation_mode_config.get(head_name, "relu") + self.activations[head_name] = { + "relu": nn.ReLU(), + "gelu": nn.GELU(), + "silu": nn.SiLU(), + }[activation_mode] + + if self.layer_norm_config.get(head_name, False): + self.norms[head_name] = nn.LayerNorm(hidden_size) + + dropout_rate = self.dropout_rate_config.get(head_name, 0.0) + self.dropouts[head_name] = nn.Dropout(dropout_rate) + + classifier_input_size = hidden_size + else: + classifier_input_size = self.embedding_size + + self.classifiers[head_name] = torch.nn.Linear( + classifier_input_size, self.num_classes[head_name] + ) + + def _compute_class_weights( + self, freq_dict: Dict[str, int], label2id: Dict[str, int] + ) -> torch.Tensor: + """ + Compute class weights from a frequency dictionary. + + Parameters + ---------- + freq_dict : dict[str, int] + Mapping from label string to its frequency. + label2id : dict[str, int] + Mapping from label string to class index. + + Returns + ------- + torch.Tensor + A weight vector aligned with label indices, where each weight is + proportional to the inverse of the label frequency. + """ + total_samples = sum(freq_dict.values()) + weights = torch.zeros(len(label2id)) + + for label, freq in freq_dict.items(): + if label in label2id: + weight = total_samples / (len(label2id) * freq) + weights[label2id[label]] = weight + + return weights + + def set_extensions(self) -> None: + """ + Register custom spaCy extensions for storing predictions. + + For each head in `self.head_names`, adds an attribute + `doc._.` if it does not already exist. + """ + super().set_extensions() + for head_name in self.head_names: + if not Doc.has_extension(head_name): + Doc.set_extension(head_name, default={}) + + def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]): + """ + Finalize initialization after gold data is available. + + - Builds label mappings (`label2id`, `id2label`) for each head if missing. + - Infers label sets from pickle files or scans gold data. + - Builds classifiers once `num_classes` are known. + - Initializes loss functions per head (CrossEntropy or Focal). + + Parameters + ---------- + gold_data : Iterable[Doc] + Training documents containing gold labels. + exclude : set + Components to exclude from initialization. 
+ """ + for head_name in self.head_names: + if not self.label2id[head_name]: + if head_name in self.labels_from_pickle: + labels = self.labels_from_pickle[head_name] + else: + labels = set() + for doc in gold_data: + label = getattr(doc._, head_name, None) + if isinstance(label, str): + labels.add(label) + + if labels: + self.label2id[head_name] = { + label: i for i, label in enumerate(labels) + } + self.id2label[head_name] = { + i: label for i, label in enumerate(labels) + } + self.num_classes[head_name] = len(labels) + print(f"Head '{head_name}': {self.num_classes[head_name]} classes") + + self.build_classifiers() + + self.loss_fns = {} + for head_name in self.head_names: + weight_tensor = None + if head_name in self.class_weights: + weight_tensor = self._compute_class_weights( + self.class_weights[head_name], self.label2id[head_name] + ) + print(f"Head '{head_name}' - Using class weights: {weight_tensor}") + + loss_type = self.loss_config.get(head_name, "ce") + if loss_type == "ce": + self.loss_fns[head_name] = torch.nn.CrossEntropyLoss( + weight=weight_tensor + ) + elif loss_type == "focal": + self.loss_fns[head_name] = FocalLoss( + alpha=weight_tensor, gamma=2.0, reduction="mean" + ) + else: + raise ValueError(f"Unknown loss for head '{head_name}': {loss_type}") + + print("Loss functions initialized") + super().post_init(gold_data, exclude=exclude) + + def preprocess(self, doc: Doc) -> Dict[str, Any]: + """ + Preprocess a single document for inference. + + Parameters + ---------- + doc : Doc + Input spaCy/edsnlp `Doc`. + + Returns + ------- + dict + Dictionary containing the pooled embedding of the document. + """ + return {"embedding": self.embedding.preprocess(doc)} + + def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]: + """ + Preprocess a single document for training. + + Adds gold labels for each head to the embedding dict, mapping labels + to integer indices when possible. + + Parameters + ---------- + doc : Doc + Input document with gold labels stored in `doc._.`. + + Returns + ------- + dict + Dictionary with: + - `"embedding"` : document embedding + - `"targets_"` : gold target tensor for each head + """ + preps = self.preprocess(doc) + targets = {} + + for head_name in self.head_names: + label = getattr(doc._, head_name, None) + if label is None: + raise ValueError( + f"Document does not have a gold label in 'doc._.{head_name}'" + ) + + if isinstance(label, str) and head_name in self.label2id: + if label not in self.label2id[head_name]: + raise ValueError( + f"Label '{label}' not in label2id for head '{head_name}'." + ) + label = self.label2id[head_name][label] + + targets[head_name] = torch.tensor(label, dtype=torch.long) + return { + **preps, + **{ + f"targets_{head_name}": targets[head_name] + for head_name in self.head_names + }, + } + + def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput: + """ + Collate a batch of preprocessed documents. + + Combines embeddings and per-head target tensors into a single batch. + + Parameters + ---------- + batch : dict + A list of per-document dicts returned by `preprocess_supervised`. + + Returns + ------- + DocClassifierBatchInput + Batched embeddings and optional targets. 
+ """ + embeddings = self.embedding.collate(batch["embedding"]) + batch_input: DocClassifierBatchInput = {"embedding": embeddings} + + collated_targets = {} + for head_name in self.head_names: + key = f"targets_{head_name}" + if key in batch: + collated_targets[head_name] = torch.stack(batch[key]) + if collated_targets: + batch_input["targets"] = collated_targets + + return batch_input + + def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput: + """ + Forward pass through the model. + + - Computes shared embeddings. + - Applies each classification head independently. + - Computes per-head losses (if targets provided) and averages them. + - Otherwise, returns predicted class indices for each head. + + Parameters + ---------- + batch : DocClassifierBatchInput + Batched embeddings and optional targets. + + Returns + ------- + DocClassifierBatchOutput + Dict with `"loss"` (training mode) or `"labels"` (inference mode). + """ + pooled = self.embedding(batch["embedding"]) + shared_embeddings = pooled["embeddings"] + + head_logits = {} + for head_name in self.head_names: + if head_name in self.classifiers: + x = shared_embeddings + + if head_name in self.hidden_layers: + x = self.hidden_layers[head_name](x) + x = self.activations[head_name](x) + if head_name in self.norms: + x = self.norms[head_name](x) + x = self.dropouts[head_name](x) + + head_logits[head_name] = self.classifiers[head_name](x) + + output: DocClassifierBatchOutput = {} + + if "targets" in batch: + head_losses = [] + for head_name in self.head_names: + if head_name in head_logits and head_name in batch["targets"]: + logits = head_logits[head_name] + targets = batch["targets"][head_name].to(logits.device) + + loss_fn = self.loss_fns[head_name] + if hasattr(loss_fn, "weight") and loss_fn.weight is not None: + loss_fn.weight = loss_fn.weight.to(logits.device) + + head_loss = loss_fn(logits, targets) + head_losses.append(head_loss) + + output["loss"] = torch.stack(head_losses).mean() if head_losses else None + output["labels"] = None + else: + head_predictions = { + head_name: torch.argmax(logits, dim=-1) + for head_name, logits in head_logits.items() + } + output["loss"] = None + output["labels"] = head_predictions + + return output + + def postprocess(self, docs, results, input): + """ + Attach predictions to documents after inference. + + For each head, predicted labels are mapped back to strings using + `id2label` and stored in `doc._.`. + + Parameters + ---------- + docs : list[Doc] + Documents processed by the pipeline. + results : dict + Output of the forward pass (`"labels"`). + input : dict + Input batch (unused). + + Returns + ------- + list[Doc] + The same documents with predictions stored in extensions. + """ + labels_dict = results["labels"] + if labels_dict is None: + return docs + + for head_name, labels in labels_dict.items(): + if isinstance(labels, torch.Tensor): + labels = labels.tolist() + + for doc, label in zip(docs, labels): + if head_name in self.id2label and isinstance(label, int): + label = self.id2label[head_name].get(label, label) + setattr(doc._, head_name, label) + + return docs + + def to_disk(self, path, *, exclude=set()): + """ + Save the classifier state to disk. + + Stores: + - Label attributes and mappings + - Per-head configuration (loss, hidden size, dropout, etc.) + + Parameters + ---------- + path : Path + Directory where files are saved. + exclude : set, optional + Components to exclude from saving. 
+ """ + repr_id = object.__repr__(self) + if repr_id in exclude: + return + os.makedirs(path, exist_ok=True) + data_path = path / "multi_head_data.pkl" + with open(data_path, "wb") as f: + pickle.dump( + { + "label_attr": self.label_attr, + "label2id": self.label2id, + "id2label": self.id2label, + "loss_config": self.loss_config, + "hidden_size_config": self.hidden_size_config, + "activation_mode_config": self.activation_mode_config, + "dropout_rate_config": self.dropout_rate_config, + "layer_norm_config": self.layer_norm_config, + }, + f, + ) + return super().to_disk(path, exclude=exclude) + + def from_disk(self, path, exclude=tuple()): + repr_id = object.__repr__(self) + if repr_id in exclude: + return + data_path = path / "multi_head_data.pkl" + with open(data_path, "rb") as f: + data = pickle.load(f) + self.label_attr = data.get("label_attr", []) + self.head_names = self.label_attr + self.label2id = data.get("label2id", {}) + self.id2label = data.get("id2label", {}) + self.loss_config = data.get("loss_config", {}) + self.hidden_size_config = data.get("hidden_size_config", {}) + self.activation_mode_config = data.get("activation_mode_config", {}) + self.dropout_rate_config = data.get("dropout_rate_config", {}) + self.layer_norm_config = data.get("layer_norm_config", {}) + super().from_disk(path, exclude=exclude) diff --git a/edsnlp/pipes/trainable/doc_classifier/factory.py b/edsnlp/pipes/trainable/doc_classifier/factory.py new file mode 100644 index 0000000000..a029815b5b --- /dev/null +++ b/edsnlp/pipes/trainable/doc_classifier/factory.py @@ -0,0 +1,9 @@ +from edsnlp import registry + +from .doc_classifier import TrainableDocClassifier + +create_component = registry.factory.register( + "eds.doc_classifier", + assigns=["doc._.predicted_class"], + deprecated=[], +)(TrainableDocClassifier) diff --git a/edsnlp/pipes/trainable/embeddings/doc_pooler/__init__.py b/edsnlp/pipes/trainable/embeddings/doc_pooler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/edsnlp/pipes/trainable/embeddings/doc_pooler/doc_pooler.py b/edsnlp/pipes/trainable/embeddings/doc_pooler/doc_pooler.py new file mode 100644 index 0000000000..11571dcb3b --- /dev/null +++ b/edsnlp/pipes/trainable/embeddings/doc_pooler/doc_pooler.py @@ -0,0 +1,126 @@ +from typing import Any, Dict, Optional + +import torch +from spacy.tokens import Doc +from typing_extensions import Literal, TypedDict + +from edsnlp.core.pipeline import Pipeline +from edsnlp.core.torch_component import BatchInput +from edsnlp.pipes.base import BaseComponent +from edsnlp.pipes.trainable.embeddings.typing import WordEmbeddingComponent + +DocPoolerBatchInput = TypedDict( + "DocPoolerBatchInput", + { + "embedding": BatchInput, + "mask": torch.Tensor, + "stats": Dict[str, Any], + }, +) + +DocPoolerBatchOutput = TypedDict( + "DocPoolerBatchOutput", + { + "embeddings": torch.Tensor, + }, +) + + +class DocPooler(WordEmbeddingComponent, BaseComponent): + """ + Pools word embeddings over the entire document to produce + a single embedding per doc. + + Parameters + ---------- + nlp: Pipeline + The pipeline object + name: str + Name of the component + embedding : WordEmbeddingComponent + The word embedding component + pooling_mode: Literal["max", "sum", "mean"] + How word embeddings are aggregated into a single embedding per document. + hidden_size : Optional[int] + The size of the hidden layer. If None, no projection is done. 
+ """ + + def __init__( + self, + nlp: Optional[Pipeline] = None, + name: str = "document_pooler", + *, + embedding: WordEmbeddingComponent, + pooling_mode: Literal["max", "sum", "mean", "cls", "attention"] = "mean", + ): + super().__init__(nlp, name) + self.embedding = embedding + self.pooling_mode = pooling_mode + self.output_size = embedding.output_size + + # Add attention layer if needed + if pooling_mode == "attention": + self.attention = torch.nn.Linear(self.output_size, 1) + + def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]: + embedding_out = self.embedding.preprocess(doc, **kwargs) + return { + "embedding": embedding_out, + "stats": {"doc_length": len(doc)}, + } + + def collate(self, batch: Dict[str, Any]) -> DocPoolerBatchInput: + embedding_batch = self.embedding.collate(batch["embedding"]) + stats = batch["stats"] + return { + "embedding": embedding_batch, + "stats": { + "doc_length": sum(stats["doc_length"]) + }, # <-- sum(...) pour aggréger les comptes par doc en un compte par batch + } + + def forward(self, batch: DocPoolerBatchInput) -> DocPoolerBatchOutput: + """ + Forward pass: compute document embeddings using the selected pooling strategy + """ + embeds = self.embedding(batch["embedding"])["embeddings"].refold( + "context", "word" + ) + device = embeds.device + + if self.pooling_mode == "cls": + pooled = self.embedding(batch["embedding"])["cls"].to(device) + return {"embeddings": pooled} + + mask = embeds.mask + + if self.pooling_mode == "attention": + attention_weights = self.attention(embeds) # (batch_size, seq_len, 1) + attention_weights = attention_weights.squeeze(-1) # (batch_size, seq_len) + + attention_weights = attention_weights.masked_fill(~mask, float("-inf")) + + attention_weights = torch.softmax(attention_weights, dim=1) + + attention_weights = attention_weights.unsqueeze( + -1 + ) # (batch_size, seq_len, 1) + pooled = (embeds * attention_weights).sum(dim=1) # (batch_size, embed_dim) + + else: + mask_expanded = mask.unsqueeze(-1) + masked_embeds = embeds * mask_expanded + sum_embeds = masked_embeds.sum(dim=1) + + if self.pooling_mode == "mean": + valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = sum_embeds / valid_counts + elif self.pooling_mode == "max": + masked_embeds = embeds.masked_fill(~mask_expanded, float("-inf")) + pooled, _ = masked_embeds.max(dim=1) + elif self.pooling_mode == "sum": + pooled = sum_embeds + else: + raise ValueError(f"Unknown pooling mode: {self.pooling_mode}") + + return {"embeddings": pooled} diff --git a/edsnlp/pipes/trainable/embeddings/doc_pooler/factory.py b/edsnlp/pipes/trainable/embeddings/doc_pooler/factory.py new file mode 100644 index 0000000000..fbed45a685 --- /dev/null +++ b/edsnlp/pipes/trainable/embeddings/doc_pooler/factory.py @@ -0,0 +1,9 @@ +from edsnlp import registry + +from .doc_pooler import DocPooler + +create_component = registry.factory.register( + "eds.doc_pooler", + assigns=[], + deprecated=[], +)(DocPooler) diff --git a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py index 5e58cd9bb6..3032c574f8 100644 --- a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py +++ b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py @@ -22,21 +22,34 @@ "SpanPoolerBatchInput", { "embedding": BatchInput, - "begins": ft.FoldedTensor, - "ends": ft.FoldedTensor, - "sequence_idx": torch.Tensor, - "stats": TypedDict("SpanPoolerBatchStats", {"spans": int}), + "span_begins": ft.FoldedTensor, + "span_ends": 
ft.FoldedTensor, + "span_contexts": ft.FoldedTensor, + "item_indices": torch.LongTensor, + "span_offsets": torch.LongTensor, + "span_indices": torch.LongTensor, + "stats": Dict[str, int], }, ) """ -embeds: torch.FloatTensor - Token embeddings to predict the tags from -begins: torch.LongTensor +Attributes +---------- +embedding: BatchInput + The input batch for the word embedding component +span_begins: ft.FoldedTensor Begin offsets of the spans -ends: torch.LongTensor +span_ends: ft.FoldedTensor End offsets of the spans -sequence_idx: torch.LongTensor - Sequence (cf Embedding spans) index of the spans +span_contexts: ft.FoldedTensor + Sequence/context index of the spans +item_indices: torch.LongTensor + Indices of the span's tokens in the tokens embedding output +span_offsets: torch.LongTensor + Offsets of the spans in the flattened span tokens +span_indices: torch.LongTensor + Span index of each token in the flattened span tokens +stats: Dict[str, int] + Statistics about the batch, e.g. number of spans """ SpanPoolerBatchOutput = TypedDict( @@ -45,6 +58,12 @@ "embeddings": ft.FoldedTensor, }, ) +""" +Attributes +---------- +embeddings: ft.FoldedTensor + The output span embeddings, with foldable dimensions ("sample", "span") +""" class SpanPooler(SpanEmbeddingComponent, BaseComponent): @@ -61,8 +80,14 @@ class SpanPooler(SpanEmbeddingComponent, BaseComponent): Name of the component embedding : WordEmbeddingComponent The word embedding component - pooling_mode: Literal["max", "sum", "mean"] - How word embeddings are aggregated into a single embedding per span. + pooling_mode: Literal["max", "sum", "mean", "attention"] + How word embeddings are aggregated into a single embedding per span: + + - "max": max pooling + - "sum": sum pooling + - "mean": mean pooling + - "attention": attention pooling, where attention scores are computed using a + linear layer followed by a softmax over the tokens in the span. hidden_size : Optional[int] The size of the hidden layer. If None, no projection is done and the output of the span pooler is used directly. 
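The SpanPooler docstring above describes the new `attention` pooling mode: a bias-free linear layer scores each token of a span, the scores are softmax-normalized within that span, and the span embedding is the weighted sum of its token embeddings. The component itself implements this on flattened index tensors (see `_pool_spans` further down); the snippet below is only a conceptual sketch of the same masked-softmax pooling on a padded `(n_spans, max_len, dim)` tensor, with illustrative shapes and names that are not part of the EDS-NLP API.

```python
import torch


def attention_pool(
    embeds: torch.Tensor,  # (n_spans, max_len, dim) padded token embeddings
    mask: torch.Tensor,  # (n_spans, max_len) True for real (non-padding) tokens
    scorer: torch.nn.Linear,  # bias-free Linear(dim, 1): one scalar score per token
) -> torch.Tensor:
    scores = scorer(embeds).squeeze(-1)  # (n_spans, max_len)
    scores = scores.masked_fill(~mask, float("-inf"))  # padding gets zero weight
    weights = torch.softmax(scores, dim=-1).unsqueeze(-1)  # (n_spans, max_len, 1)
    return (embeds * weights).sum(dim=1)  # (n_spans, dim)


# Illustrative call with random data
embeds = torch.randn(4, 7, 16)
mask = torch.arange(7)[None, :] < torch.tensor([7, 3, 5, 2])[:, None]
pooled = attention_pool(embeds, mask, torch.nn.Linear(16, 1, bias=False))
print(pooled.shape)  # torch.Size([4, 16])
```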
@@ -74,7 +99,9 @@ def __init__( name: str = "span_pooler", *, embedding: WordEmbeddingComponent, - pooling_mode: Literal["max", "sum", "mean"] = "mean", + pooling_mode: Literal["max", "sum", "mean", "attention"] = "mean", + activation: Optional[Literal["relu", "gelu", "silu"]] = None, + norm: Optional[Literal["layernorm", "batchnorm"]] = None, hidden_size: Optional[int] = None, span_getter: Any = None, ): @@ -99,11 +126,35 @@ def __init__( self.pooling_mode = pooling_mode self.span_getter = span_getter self.embedding = embedding - self.projector = ( - torch.nn.Linear(self.embedding.output_size, hidden_size) - if hidden_size is not None - else torch.nn.Identity() - ) + self.activation = activation + self.projector = torch.nn.Sequential() + if hidden_size is not None: + self.projector.append( + torch.nn.Linear(self.embedding.output_size, hidden_size) + ) + if activation is not None: + self.projector.append( + { + "relu": torch.nn.ReLU, + "gelu": torch.nn.GELU, + "silu": torch.nn.SiLU, + }[activation]() + ) + if norm is not None: + self.projector.append( + { + "layernorm": torch.nn.LayerNorm, + "batchnorm": torch.nn.BatchNorm1d, + }[norm]( + hidden_size + if hidden_size is not None + else self.embedding.output_size + ) + ) + if self.pooling_mode in {"attention"}: + self.attention_scorer = torch.nn.Linear( + self.embedding.output_size, 1, bias=False + ) def feed_forward(self, span_embeds: torch.Tensor) -> torch.Tensor: return self.projector(span_embeds) @@ -112,18 +163,23 @@ def preprocess( self, doc: Doc, *, - spans: Optional[Sequence[Span]] = None, + spans: Optional[Sequence[Span]], contexts: Optional[Sequence[Span]] = None, pre_aligned: bool = False, **kwargs, ) -> Dict[str, Any]: - contexts = contexts if contexts is not None else [doc[:]] + if contexts is None: + contexts = [doc[:]] * len(spans) + pre_aligned = True - sequence_idx = [] + context_indices = [] begins = [] ends = [] - contexts_to_idx = {span: i for i, span in enumerate(contexts)} + contexts_to_idx = {} + for ctx in contexts: + if ctx not in contexts_to_idx: + contexts_to_idx[ctx] = len(contexts_to_idx) assert not pre_aligned or len(spans) == len(contexts), ( "When `pre_aligned` is True, the number of spans and contexts must be the " "same." 
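With the reworked `preprocess` above, spans can now be pooled over either word or wordpiece embeddings. Below is a hedged configuration sketch combining the two features introduced in this PR (`pooling_mode="attention"` in `eds.span_pooler` and `word_pooling_mode=False` in `eds.transformer`); the model name and window/stride values are simply those used in the accompanying tests, not a recommendation.

```python
import edsnlp
import edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.span_pooler(
        # New in this PR: softmax-weighted sum over the tokens of each span
        pooling_mode="attention",
        embedding=eds.transformer(
            model="prajjwal1/bert-tiny",
            window=128,
            stride=96,
            # New in this PR: return wordpiece embeddings directly, so the
            # span pooler pools over wordpieces instead of mean-pooled words
            word_pooling_mode=False,
        ),
    ),
    name="span_pooler",
)
```

In practice this pooler is nested inside a trainable task component such as `eds.span_classifier`; on its own it only exposes `preprocess`/`collate`/`forward`, as exercised in the new `tests/pipelines/trainable/test_span_pooler.py`.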
@@ -140,52 +196,96 @@ def preprocess( f"span: {[s.text for s in ctx]}" ) start = ctx[0].start - sequence_idx.append(contexts_to_idx[ctx[0]]) + context_indices.append(contexts_to_idx[ctx[0]]) begins.append(span.start - start) ends.append(span.end - start) return { "begins": begins, "ends": ends, - "sequence_idx": sequence_idx, + "span_to_ctx_idx": context_indices, "num_sequences": len(contexts), - "embedding": self.embedding.preprocess(doc, contexts=contexts, **kwargs), + "embedding": self.embedding.preprocess( + doc, contexts=list(contexts_to_idx), **kwargs + ), "stats": {"spans": len(begins)}, } def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanPoolerBatchInput: - sequence_idx = [] - offset = 0 - for indices, seq_length in zip(batch["sequence_idx"], batch["num_sequences"]): - sequence_idx.extend([offset + idx for idx in indices]) - offset += seq_length + embedding_batch = self.embedding.collate(batch["embedding"]) + embed_structure = embedding_batch["out_structure"] + ft_kw = dict( + data_dims=("span",), + full_names=("sample", "span"), + dtype=torch.long, + ) + begins = ft.as_folded_tensor(batch["begins"], **ft_kw) + ends = ft.as_folded_tensor(batch["ends"], **ft_kw) + span_to_ctx_idx = [] + total_num_ctx = 0 + for i, (ctx_indices, num_ctx) in enumerate( + zip(batch["span_to_ctx_idx"], embed_structure["context"]) + ): + span_to_ctx_idx.append([idx + total_num_ctx for idx in ctx_indices]) + total_num_ctx += num_ctx + flat_span_to_ctx_idx = ft.as_folded_tensor(span_to_ctx_idx, **ft_kw) + item_indices, span_offsets, span_indices = embed_structure.make_indices_ranges( + begins=(flat_span_to_ctx_idx, begins), + ends=(flat_span_to_ctx_idx, ends), + indice_dims=("context", "word"), + ) collated: SpanPoolerBatchInput = { - "embedding": self.embedding.collate(batch["embedding"]), - "begins": ft.as_folded_tensor( - batch["begins"], - data_dims=("span",), - full_names=("sample", "span"), - dtype=torch.long, - ), - "ends": ft.as_folded_tensor( - batch["ends"], - data_dims=("span",), - full_names=("sample", "span"), - dtype=torch.long, - ), - "sequence_idx": torch.as_tensor(sequence_idx), + "embedding": embedding_batch, + "span_begins": begins, + "span_ends": ends, + "span_contexts": flat_span_to_ctx_idx, + "item_indices": item_indices, + "span_offsets": begins.with_data(span_offsets), + "span_indices": span_indices, "stats": {"spans": sum(batch["stats"]["spans"])}, } return collated + def _pool_spans(self, embeds, span_indices, span_offsets, item_indices=None): + dev = span_offsets.device + dim = embeds.size(-1) + embeds = embeds.as_tensor().view(-1, dim) + N = span_offsets.numel() # number of spans + + if self.pooling_mode == "attention": + if item_indices is not None: + embeds = embeds[item_indices] + weights = self.attention_scorer(embeds) + # compute max for softmax stability + dtype = weights.dtype + max_weights = torch.full((N, 1), float("-inf"), device=dev, dtype=dtype) + max_weights.index_reduce_(0, span_indices, weights, reduce="amax") + # softmax numerator + exp_scores = torch.exp(weights - max_weights[span_indices]) + # softmax denominator + denom = torch.zeros((N, 1), device=dev, dtype=exp_scores.dtype) + denom.index_add_(0, span_indices, exp_scores) + # softmax output = embeds * weight num / weight denom + weighted_embeds = embeds * exp_scores / denom[span_indices] + span_embeds = torch.zeros((N, dim), device=dev, dtype=embeds.dtype) + span_embeds.index_add_(0, span_indices, weighted_embeds) + span_embeds = span_offsets.with_data(span_embeds) + else: + span_embeds = 
torch.nn.functional.embedding_bag( # type: ignore + input=torch.arange(len(embeds), device=dev) + if item_indices is None + else item_indices, + weight=embeds, + offsets=span_offsets, + mode=self.pooling_mode, + ) + span_embeds = self.feed_forward(span_embeds) + return span_embeds + # noinspection SpellCheckingInspection def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput: """ - Apply the span classifier module to the document embeddings and given spans to: - - compute the loss - - and/or predict the labels of spans - If labels are predicted, they are assigned to the `additional_outputs` - dictionary. + Forward pass of the component, returns span embeddings. Parameters ---------- @@ -194,38 +294,26 @@ def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput: Returns ------- - BatchOutput + SpanPoolerBatchOutput """ - device = next(self.parameters()).device - if len(batch["begins"]) == 0: - span_embeds = torch.empty(0, self.output_size, device=device) + if len(batch["span_begins"]) == 0: return { - "embeddings": batch["begins"].with_data(span_embeds), + "embeddings": batch["span_begins"].with_data( + torch.empty( + 0, + self.output_size, + device=batch["item_indices"].device, + ) + ), } embeds = self.embedding(batch["embedding"])["embeddings"] - _, n_words, dim = embeds.shape - device = embeds.device - - flat_begins = n_words * batch["sequence_idx"] + batch["begins"].as_tensor() - flat_ends = n_words * batch["sequence_idx"] + batch["ends"].as_tensor() - flat_embeds = embeds.view(-1, dim) - flat_indices = torch.cat( - [ - torch.arange(b, e, device=device) - for b, e in zip(flat_begins.cpu().tolist(), flat_ends.cpu().tolist()) - ] - ).to(device) - offsets = (flat_ends - flat_begins).cumsum(0).roll(1) - offsets[0] = 0 - span_embeds = torch.nn.functional.embedding_bag( # type: ignore - input=flat_indices, - weight=flat_embeds, - offsets=offsets, - mode=self.pooling_mode, + span_embeds = self._pool_spans( + embeds, + span_indices=batch["span_indices"], + span_offsets=batch["span_offsets"], + item_indices=batch["item_indices"], ) - span_embeds = self.feed_forward(span_embeds) - return { - "embeddings": batch["begins"].with_data(span_embeds), + "embeddings": batch["span_begins"].with_data(span_embeds), } diff --git a/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py b/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py index c1de49a562..92fff74928 100644 --- a/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py +++ b/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence +from typing import Any, Dict, Optional, Sequence import torch from typing_extensions import Literal, TypedDict @@ -87,6 +87,10 @@ def __init__( normalize=normalize, ) + def collate(self, batch: Dict[str, Any]) -> BatchInput: + emb = self.embedding.collate(batch["embedding"]) + return {"embedding": emb, "out_structure": emb["out_structure"]} + def forward(self, batch: BatchInput) -> WordEmbeddingBatchOutput: """ Encode embeddings with a 1d convolutional network diff --git a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py index 4d830858e8..f1d0e186fb 100644 --- a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py +++ b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py @@ -34,6 +34,8 @@ }, ) """ +Attributes +---------- input_ids: FoldedTensor Tokenized input (prompt + text) to embed word_indices: torch.LongTensor @@ -51,6 +53,8 @@ }, ) 
""" +Attributes +---------- embeddings: FoldedTensor The embeddings of the words """ @@ -101,9 +105,23 @@ class Transformer(WordEmbeddingComponent[TransformerBatchInput]): stride=96, ), ) + + doc1 = nlp.make_doc("My name is Michael.") + doc2 = nlp.make_doc("And I am the best boss in the world.") + prep = nlp.pipes.transformer.preprocess_batch([doc1, doc2]) + batch = nlp.pipes.transformer.collate(prep) + res = nlp.pipes.transformer(batch) + + # Embeddings are flattened by default + print(res["embeddings"].shape) + # Out: torch.Size([15, 128]) + + # But they can be refolded to materialize the sample dimension + print(res["embeddings"].refold("sample", "word").shape) + # Out: torch.Size([2, 10, 128]) ``` - You can then compose this embedding with a task specific component such as + You can compose this embedding with a task specific component such as `eds.ner_crf`. Parameters @@ -131,9 +149,25 @@ class Transformer(WordEmbeddingComponent[TransformerBatchInput]): If "auto", the component will try to estimate the maximum number of tokens that can be processed by the model on the current device at a given time. - span_getter: Optional[SpanGetterArg] - Which spans of the document should be embedded. Defaults to the full document - if None. + new_tokens: Optional[List[Tuple[str, str]]] + A list of (pattern, replacement) tuples to add to the tokenizer. The pattern + should be a valid regular expression. The replacement should be a string. + + This can be used to add new tokens to the tokenizer that are not present in the + original vocabulary. For example, if you want to add a new token for new lines + you can use the following: + + ```python + new_tokens = [("\\n", "⏎")] + ``` + quantization: Optional[BitsAndBytesConfig] + Quantization configuration to use for the model. If None, no quantization + will be applied. This requires the `bitsandbytes` library to be installed. + word_pooling_mode: Literal["mean", False] + If "mean", the embeddings of the wordpieces corresponding to each word will be + averaged to produce a single embedding per word. If False, the embeddings of the + wordpieces will be returned as a FoldedTensor with an additional "token" + dimension. 
(default: "mean") """ def __init__( @@ -149,6 +183,7 @@ def __init__( span_getter: Optional[SpanGetterArg] = None, new_tokens: Optional[List[Tuple[str, str]]] = [], quantization: Optional[BitsAndBytesConfig] = None, + word_pooling_mode: Literal["mean", False] = "mean", **kwargs, ): super().__init__(nlp, name) @@ -167,6 +202,7 @@ def __init__( kwargs["quantization_config"] = quantization self.transformer = AutoModel.from_pretrained(model, **kwargs) + self.word_pooling_mode = word_pooling_mode try: self.tokenizer = AutoTokenizer.from_pretrained(model) except (HTTPException, ConnectionError): # pragma: no cover @@ -353,6 +389,7 @@ def collate(self, batch): empty_word_indices = [] overlap = self.window - stride word_offset = 0 + word_sizes = [] all_word_wp_offset = 0 for ( sample_text_input_ids, @@ -422,23 +459,38 @@ def collate(self, batch): ] ] word_indices.extend(word_wp_indices) + word_sizes.append(length) word_wp_offset += length word_offset += 1 all_word_wp_offset += word_wp_offset + word_offsets = ft.as_folded_tensor( + word_offsets, + data_dims=("word",), + full_names=("sample", "context", "word"), + dtype=torch.long, + ) + out_structure = ( + ft.FoldedTensorLayout( + [ + *word_offsets.lengths, + word_sizes, + ], + full_names=("sample", "context", "word", "token"), + data_dims=("token",), + ) + if not self.word_pooling_mode + else word_offsets.lengths + ) return { + "out_structure": out_structure, "input_ids": ft.as_folded_tensor( input_ids, data_dims=("context", "subword"), full_names=("context", "subword"), dtype=torch.long, ), - "word_offsets": ft.as_folded_tensor( - word_offsets, - data_dims=("word",), - full_names=("sample", "context", "word"), - dtype=torch.long, - ), + "word_offsets": word_offsets, "word_indices": torch.as_tensor(word_indices, dtype=torch.long), "empty_word_indices": torch.as_tensor(empty_word_indices, dtype=torch.long), "stats": { @@ -480,7 +532,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput: max_windows = max(1, max_tokens // input_ids.size(1)) total_windows = input_ids.size(0) try: - wordpiece_embeddings = [ + wp_embs = [ self.transformer.base_model( **{ k: None if v is None else v[offset : offset + max_windows] @@ -490,11 +542,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput: for offset in range(0, total_windows, max_windows) ] - wordpiece_embeddings = ( - torch.cat(wordpiece_embeddings, dim=0) - if len(wordpiece_embeddings) > 1 - else wordpiece_embeddings[0] - ) + wp_embs = torch.cat(wp_embs, dim=0) if len(wp_embs) > 1 else wp_embs[0] if auto_batch_size: # pragma: no cover batch_mem = torch.cuda.max_memory_allocated(device) @@ -520,26 +568,35 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput: # mask = batch["mask"].clone() # word_embeddings = torch.zeros( - # (mask.size(0), mask.size(1), wordpiece_embeddings.size(2)), + # (mask.size(0), mask.size(1), wp_embs.size(2)), # dtype=torch.float, # device=device, # ) # embeddings_plus_empty = torch.cat( # [ - # wordpiece_embeddings.view(-1, wordpiece_embeddings.size(2)), + # wp_embs.view(-1, wp_embs.size(2)), # self.empty_word_embedding, # ], # dim=0, # ) - word_embeddings = torch.nn.functional.embedding_bag( - input=batch["word_indices"], - weight=wordpiece_embeddings.reshape(-1, wordpiece_embeddings.size(2)), - offsets=batch["word_offsets"], - ) - word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding - return { - "embeddings": word_embeddings.refold("context", "word"), - } + if self.word_pooling_mode == 
"mean": + word_embeddings = torch.nn.functional.embedding_bag( + input=batch["word_indices"], + weight=wp_embs.reshape(-1, wp_embs.size(2)), + offsets=batch["word_offsets"], + ) + word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding + return { + "embeddings": word_embeddings, + "cls": wp_embs[:, 0, :], + } + else: + wp_embs = wp_embs.reshape(-1, self.output_size)[batch["word_indices"]] + wp_embs = ft.as_folded_tensor(wp_embs, lengths=batch["out_structure"]) + return { + "embeddings": wp_embs, + "cls": wp_embs[:, 0, :], + } @staticmethod def align_words_with_trf_tokens(doc, trf_char_indices): diff --git a/edsnlp/pipes/trainable/embeddings/typing.py b/edsnlp/pipes/trainable/embeddings/typing.py index c044cb7951..94d4891164 100644 --- a/edsnlp/pipes/trainable/embeddings/typing.py +++ b/edsnlp/pipes/trainable/embeddings/typing.py @@ -30,8 +30,7 @@ def preprocess( *, contexts: Optional[List[Span]], **kwargs, - ) -> Dict[str, Any]: - ... + ) -> Dict[str, Any]: ... class WordContextualizerComponent( @@ -67,5 +66,4 @@ def preprocess( contexts: Optional[List[Span]], pre_aligned: bool = False, **kwargs, - ) -> Dict[str, Any]: - ... + ) -> Dict[str, Any]: ... diff --git a/edsnlp/resources/verbs.csv.gz b/edsnlp/resources/verbs.csv.gz index b05fb4eeff..b74c8587c2 100644 Binary files a/edsnlp/resources/verbs.csv.gz and b/edsnlp/resources/verbs.csv.gz differ diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py index 634e2296ab..81c7fb8c06 100644 --- a/edsnlp/training/trainer.py +++ b/edsnlp/training/trainer.py @@ -695,6 +695,8 @@ def train( for td in train_data if td.pipe_names is None or set(td.pipe_names) & set(pipe_names) ] + for td in phase_training_data: + print("phase_training_data", td) if len(phase_training_data) == 0: raise ValueError( @@ -849,9 +851,7 @@ def train( for idx, (batch, batch_pipe_names) in enumerate( zip(batches, batches_pipe_names) ): - cache_ctx = ( - nlp.cache() if len(batch_pipe_names) > 1 else nullcontext() - ) + cache_ctx = nlp.cache() no_sync_ctx = ( accelerator.no_sync(trained_pipes) if idx < len(batches) - 1 diff --git a/edsnlp/tune.py b/edsnlp/tune.py index c7c46d6810..24b2ca891d 100644 --- a/edsnlp/tune.py +++ b/edsnlp/tune.py @@ -260,9 +260,15 @@ def update_config( current_config = config for key in p_path[:-1]: - if key not in current_config: - raise KeyError(f"Path '{key}' not found in config.") - current_config = current_config[key] + try: + current_config = current_config[key] + except KeyError: + try: + current_config = current_config[int(key)] + except (KeyError, ValueError): + raise KeyError( + f"Path '{key}' not found in config ({current_config})" + ) current_config[p_path[-1]] = value if resolve: @@ -595,7 +601,7 @@ def compute_remaining_n_trials_possible( remaining_gpu_time, compute_time_per_trial(study, ema=True) ) return n_trials - except ValueError: + except ValueError: # pragma: no cover return 0 diff --git a/pyproject.toml b/pyproject.toml index 1ccd9df4bf..52d77ce724 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ docs-no-ml = [ ml = [ "rich-logger>=0.3.1", "torch>=1.13.0; python_version>='3.9'", - "foldedtensor>=0.4.0", + "foldedtensor @ git+https://github.com/aphp/foldedtensor.git@indice-mapping#egg=foldedtensor", "safetensors>=0.3.0; python_version>='3.8'", "safetensors>=0.3.0,<0.5.0; python_version<'3.8'", "transformers>=4.0.0", @@ -258,6 +258,8 @@ where = ["."] "eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" "eds.span_linker" = 
"edsnlp.pipes.trainable.span_linker.factory:create_component" "eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" +"eds.doc_pooler" = "edsnlp.pipes.trainable.embeddings.doc_pooler.factory:create_component" +"eds.doc_classifier" = "edsnlp.pipes.trainable.doc_classifier.factory:create_component" [project.entry-points."edsnlp_schedules"] "linear" = "edsnlp.training.optimizer:LinearSchedule" @@ -268,6 +270,7 @@ where = ["."] "eds.ner_overlap" = "edsnlp.metrics.ner:NerOverlapMetric" "eds.span_attribute" = "edsnlp.metrics.span_attribute:SpanAttributeMetric" "eds.dep_parsing" = "edsnlp.metrics.dep_parsing:DependencyParsingMetric" +"eds.doc_classif" = "edsnlp.metrics.doc_classif:DocClassificationMetric" # Deprecated "eds.ner_exact_metric" = "edsnlp.metrics.ner:NerExactMetric" diff --git a/tests/pipelines/trainable/dummy_embeddings.py b/tests/pipelines/trainable/dummy_embeddings.py new file mode 100644 index 0000000000..c1cd3d3fd1 --- /dev/null +++ b/tests/pipelines/trainable/dummy_embeddings.py @@ -0,0 +1,123 @@ +from typing import List, Optional + +import pytest + +pytest.importorskip("torch") + +import foldedtensor as ft +import torch +from typing_extensions import Literal + +from edsnlp import Pipeline +from edsnlp.pipes.trainable.embeddings.typing import WordEmbeddingComponent + + +class DummyEmbeddings(WordEmbeddingComponent[dict]): + """ + For each word, embedding = (word idx in sent) * [1, 1, ..., 1] (size = dim) + """ + + def __init__( + self, + nlp: Optional[Pipeline] = None, + name: str = "fixed_embeddings", + word_pooling_mode: Literal["mean", False] = "mean", + *, + dim: int, + ): + super().__init__(nlp, name) + self.output_size = int(dim) + self.word_pooling_mode = word_pooling_mode + + def preprocess(self, doc, *, contexts=None, prompts=()): + if contexts is None: + contexts = [doc[:]] + + inputs: List[List[List[int]]] = [] + total = 0 + + for ctx in contexts: + words = [] + for word in ctx: + subwords = [] + for subword in word.text[::4]: + subwords.append(total) + total += 1 + words.append(subwords) + inputs.append(words) + + return { + "inputs": inputs, # List[Context][Word] -> int + } + + def collate(self, batch): + # Flatten indices and keep per-(sample,context) lengths to refold later + inputs = ft.as_folded_tensor( + batch["inputs"], + data_dims=("sample", "token"), + full_names=("sample", "context", "word", "token"), + dtype=torch.long, + ) + item_indices = span_offsets = span_indices = None + if self.word_pooling_mode == "mean": + samples = torch.arange(max(inputs.lengths["sample"])) + words = torch.arange(max(inputs.lengths["word"])) + n_words = len(words) + n_samples = len(samples) + words = words[None, :].expand(n_samples, -1) + samples = samples[:, None].expand(-1, n_words) + words = words.masked_fill( + ~inputs.refold("sample", "word", "token").mask.any(-1), 0 + ) + item_indices, span_offsets, span_indices = ( + inputs.lengths.make_indices_ranges( + begins=(samples, words), + ends=(samples, words + 1), + indice_dims=( + "sample", + "word", + ), + return_tensors="pt", + ) + ) + span_offsets = ft.as_folded_tensor( + span_offsets, + data_dims=( + "sample", + "word", + ), + full_names=("sample", "context", "word"), + lengths=list(inputs.lengths)[0:-1], + ) + + return { + "out_structure": span_offsets.lengths + if self.word_pooling_mode == "mean" + else inputs.lengths, + "inputs": inputs, + "item_indices": item_indices, + "span_offsets": span_offsets, + "span_indices": span_indices, + } + + def forward(self, batch): + embeddings 
diff --git a/tests/pipelines/trainable/test_doc_classifier.py b/tests/pipelines/trainable/test_doc_classifier.py
new file mode 100644
index 0000000000..22e6a0aef8
--- /dev/null
+++ b/tests/pipelines/trainable/test_doc_classifier.py
@@ -0,0 +1,33 @@
+import pytest
+
+import edsnlp
+import edsnlp.pipes as eds
+
+pytest.importorskip("torch.nn")
+
+
+@pytest.mark.parametrize("pooling_mode", ["mean", "max", "cls", "sum"])
+@pytest.mark.parametrize("label_attr", ["label", "alive"])
+@pytest.mark.parametrize("num_classes", [2, 10])
+def test_doc_classifier(pooling_mode, label_attr, num_classes):
+    nlp = edsnlp.blank("eds")
+    doc = nlp.make_doc("Le patient est mort.")
+
+    nlp.add_pipe(
+        eds.doc_classifier(
+            embedding=eds.doc_pooler(
+                pooling_mode=pooling_mode,
+                embedding=eds.transformer(
+                    model="prajjwal1/bert-tiny",
+                    window=128,
+                    stride=96,
+                ),
+            ),
+            num_classes=num_classes,
+            label_attr=label_attr,
+        ),
+        name="doc_classifier",
+    )
+    doc = nlp(doc)
+    label = getattr(doc._, label_attr, None)
+    assert label in range(0, num_classes)
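Outside of pytest, the same wiring can be used directly. The sketch below mirrors the test above: the model, window and stride values are the ones the test uses, while the text, class count and pooling mode are purely illustrative, and an untrained classifier will of course return an arbitrary class.

import edsnlp
import edsnlp.pipes as eds

# Document-level classification: eds.doc_pooler reduces the transformer output
# to one embedding per document, eds.doc_classifier predicts a class from it.
nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.doc_classifier(
        embedding=eds.doc_pooler(
            pooling_mode="mean",  # "max", "cls" and "sum" are also exercised above
            embedding=eds.transformer(
                model="prajjwal1/bert-tiny",
                window=128,
                stride=96,
            ),
        ),
        num_classes=2,
        label_attr="label",  # the prediction is written to doc._.label
    ),
    name="doc_classifier",
)
doc = nlp("Le patient est mort.")
print(doc._.label)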
diff --git a/tests/pipelines/trainable/test_span_pooler.py b/tests/pipelines/trainable/test_span_pooler.py
new file mode 100644
index 0000000000..932a974342
--- /dev/null
+++ b/tests/pipelines/trainable/test_span_pooler.py
@@ -0,0 +1,277 @@
+import confit.utils.random
+import pytest
+from dummy_embeddings import DummyEmbeddings
+
+import edsnlp
+import edsnlp.pipes as eds
+from edsnlp.data.converters import MarkupToDocConverter
+from edsnlp.pipes.trainable.embeddings.span_pooler.span_pooler import SpanPooler
+from edsnlp.utils.collections import batch_compress_dict, decompress_dict
+
+pytest.importorskip("torch.nn")
+
+import torch
+
+
+@pytest.mark.parametrize(
+    "word_pooling_mode,shape",
+    [
+        ("mean", (2, 5, 2)),
+        (False, (2, 6, 2)),
+    ],
+)
+def test_dummy_embeddings(word_pooling_mode, shape):
+    confit.utils.random.set_seed(42)
+    converter = MarkupToDocConverter()
+    doc1 = converter("This is a sentence.")
+    doc2 = converter("A shorter one.")
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe(
+        DummyEmbeddings(dim=2, word_pooling_mode=word_pooling_mode), name="embeddings"
+    )
+    embedder: DummyEmbeddings = nlp.pipes.embeddings
+
+    prep1 = embedder.preprocess(doc1)
+    prep2 = embedder.preprocess(doc2)
+    pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2])))
+    batch = embedder.collate(pivoted_prep)
+    out = embedder.forward(batch)["embeddings"]
+
+    assert out.shape == shape
+
+
+@pytest.mark.parametrize("span_pooling_mode", ["max", "mean", "attention"])
+def test_span_pooler_on_words(span_pooling_mode):
+    confit.utils.random.set_seed(42)
+    converter = MarkupToDocConverter()
+    doc1 = converter("[This](ent) is [a sentence](ent). This is [small one](ent).")
+    doc2 = converter("An [even shorter one](ent) !")
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe(
+        eds.span_pooler(
+            embedding=DummyEmbeddings(dim=2),
+            pooling_mode=span_pooling_mode,
+        )
+    )
+    pooler: SpanPooler = nlp.pipes.span_pooler
+
+    prep1 = pooler.preprocess(doc1, spans=doc1.ents)
+    prep2 = pooler.preprocess(doc2, spans=doc2.ents)
+    pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2])))
+    batch = pooler.collate(pivoted_prep)
+    out = pooler.forward(batch)["embeddings"]
+
+    assert out.shape == (4, 2)
+    out = out.refold("sample", "span")
+
+    assert out.shape == (2, 3, 2)
+    if span_pooling_mode == "attention":
+        expected = [
+            [[0.0000, 0.0000], [3.8102, 3.8102], [9.7554, 9.7554]],
+            [[3.6865, 3.6865], [0.0000, 0.0000], [0.0000, 0.0000]],
+        ]
+    elif span_pooling_mode == "mean":
+        expected = [
+            [[0.0000, 0.0000], [3.0000, 3.0000], [9.5000, 9.5000]],
+            [[2.6667, 2.6667], [0.0000, 0.0000], [0.0000, 0.0000]],
+        ]
+    elif span_pooling_mode == "max":
+        expected = [
+            [[0.0000, 0.0000], [4.0000, 4.0000], [10.0000, 10.0000]],
+            [[4.0000, 4.0000], [0.0000, 0.0000], [0.0000, 0.0000]],
+        ]
+    else:
+        raise ValueError(f"Unknown pooling mode: {span_pooling_mode}")
+    assert torch.allclose(out, torch.tensor(expected), atol=1e-4)
+
+
+@pytest.mark.parametrize("span_pooling_mode", ["max", "mean", "attention"])
+def test_span_pooler_on_tokens(span_pooling_mode):
+    confit.utils.random.set_seed(42)
+    converter = MarkupToDocConverter()
+    doc1 = converter("[This](ent) is [a sentence](ent). This is [small one](ent).")
+    doc2 = converter("An [even shorter one](ent) !")
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe(
+        eds.span_pooler(
+            embedding=DummyEmbeddings(dim=2, word_pooling_mode=False),
+            pooling_mode=span_pooling_mode,
+        )
+    )
+    pooler: SpanPooler = nlp.pipes.span_pooler
+
+    prep1 = pooler.preprocess(doc1, spans=doc1.ents)
+    prep2 = pooler.preprocess(doc2, spans=doc2.ents)
+    pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2])))
+    batch = pooler.collate(pivoted_prep)
+    out = pooler.forward(batch)["embeddings"]
+
+    assert out.shape == (4, 2)
+    out = out.refold("sample", "span")
+
+    assert out.shape == (2, 3, 2)
+    if span_pooling_mode == "attention":
+        expected = [
+            [[0.0000, 0.0000], [3.6265, 3.6265], [9.6265, 9.6265]],
+            [[3.5655, 3.5655], [0.0000, 0.0000], [0.0000, 0.0000]],
+        ]
+    elif span_pooling_mode == "mean":
+        expected = [
+            [[0.0000, 0.0000], [3.0000, 3.0000], [9.0000, 9.0000]],
+            [[2.5000, 2.5000], [0.0000, 0.0000], [0.0000, 0.0000]],
+        ]
+    elif span_pooling_mode == "max":
+        expected = [
+            [[0.0000, 0.0000], [4.0000, 4.0000], [10.0000, 10.0000]],
+            [[4.0000, 4.0000], [0.0000, 0.0000], [0.0000, 0.0000]],
+        ]
+    else:
+        raise ValueError(f"Unknown pooling mode: {span_pooling_mode}")
+    assert torch.allclose(out, torch.tensor(expected), atol=1e-4)
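The expected tensors in the two tests above follow directly from the dummy indices. For doc1, the span "a sentence" covers the word "a" (wordpiece 2) and the word "sentence" (wordpieces 3 and 4, reduced to 4 by the dummy's max reduction), or the raw wordpieces 2, 3 and 4 when word_pooling_mode is False; the "attention" rows depend on the seeded random attention weights and are not re-derived here. A quick standalone recomputation of the "mean" and "max" entries for that span:

span_words = [2, 4]      # word-level dummy embeddings of "a" and "sentence"
span_tokens = [2, 3, 4]  # wordpiece-level dummy embeddings of the same span
assert sum(span_words) / len(span_words) == 3.0    # "mean" pooling on words
assert max(span_words) == 4                        # "max" pooling on words
assert sum(span_tokens) / len(span_tokens) == 3.0  # "mean" pooling on wordpieces
assert max(span_tokens) == 4                       # "max" pooling on wordpieces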
+
+
+def test_span_pooler_on_flat_hf_tokens():
+    confit.utils.random.set_seed(42)
+    converter = MarkupToDocConverter()
+    doc1 = converter("[This](ent) is [a sentence](ent). This is [small one](ent).")
+    doc2 = converter("An [even](ent) [shorter one](ent) !")
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe(
+        eds.span_pooler(
+            embedding=eds.transformer(
+                model="almanach/camembert-base",
+                word_pooling_mode=False,
+            ),
+            pooling_mode="mean",
+        )
+    )
+    pooler: SpanPooler = nlp.pipes.span_pooler
+
+    prep1 = pooler.preprocess(doc1, spans=doc1.ents)
+    prep2 = pooler.preprocess(doc2, spans=doc2.ents)
+    pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2])))
+    print(
+        nlp.pipes.span_pooler.embedding.tokenizer.convert_ids_to_tokens(
+            prep2["embedding"]["input_ids"][0]
+        )
+    )
+    # fmt: off
+    assert prep1["embedding"]["input_ids"] == [
+        [
+            17526,  # ▁This: 0 -> span 0
+            2856,   # ▁is: 1
+            33,     # ▁a: 2 -> span 1
+            22625,  # ▁sentence: 3 -> span 1
+            9,      # .: 4
+            17526,  # ▁This: 5
+            2856,   # ▁is: 6
+            52,     # ▁s: 7 -> span 2
+            215,    # m: 8 -> span 2
+            3645,   # all: 9 -> span 2
+            91,     # ▁on: 10 -> span 2
+            35,     # e: 11 -> span 2
+            9,      # .: 12
+        ],
+    ]
+    # '▁An', '▁', 'even', '▁short', 'er', '▁on', 'e', '▁!'
+    assert prep2["embedding"]["input_ids"] == [
+        [
+            2764,   # ▁An: 13
+            21,     # ▁: 14
+            15999,  # even: 15 -> span 3
+            9161,   # short: 16 -> span 4
+            108,    # er: 17 -> span 4
+            91,     # ▁on: 18 -> span 4
+            35,     # e: 19 -> span 4
+            83,     # ▁!: 20
+        ]
+    ]
+    # fmt: on
+    batch = pooler.collate(pivoted_prep)
+    out = pooler.forward(batch)["embeddings"]
+
+    word_embeddings = pooler.embedding(batch["embedding"])["embeddings"]
+    assert word_embeddings.shape == (21, 768)
+
+    assert out.shape == (5, 768)
+
+    # item_indices: [0, 2, 3, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19]
+    #                -  ----  ---------------  ------  --------------
+    # span_offsets: [0, 1, 3, 8, 10]
+    # span_indices: [0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4]
+
+    assert torch.allclose(out[0], word_embeddings[0])
+    assert torch.allclose(out[1], word_embeddings[2:4].mean(0))
+    assert torch.allclose(out[2], word_embeddings[7:12].mean(0))
+    assert torch.allclose(out[3], word_embeddings[14:16].mean(0))
+    assert torch.allclose(out[4], word_embeddings[16:20].mean(0))
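The index comments in the test above describe how the pooler groups the flat wordpiece embeddings into spans. The standalone check below re-derives the per-span wordpiece groups from those two lists (values copied from the comments):

item_indices = [0, 2, 3, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19]
span_offsets = [0, 1, 3, 8, 10]
bounds = span_offsets + [len(item_indices)]
groups = [item_indices[b:e] for b, e in zip(bounds, bounds[1:])]
assert groups[1] == [2, 3]             # "a sentence" -> wordpieces 2 and 3
assert groups[2] == [7, 8, 9, 10, 11]  # "small one" -> five wordpieces
assert groups[4] == [16, 17, 18, 19]   # "shorter one" -> four wordpieces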
+ assert prep2["embedding"]["input_ids"] == [ + [ + 2764, # ▁An: 10 + 21, 15999, # ▁, even: 11 -> span 3 + 9161, 108, # short er: 12 -> span 4 + 91, 35, # ▁on e: 13 -> span 4 + 83, # ▁!: 14 + ] + ] + # fmt: on + batch = pooler.collate(pivoted_prep) + out = pooler.forward(batch)["embeddings"] + + word_embeddings = pooler.embedding(batch["embedding"])["embeddings"] + assert word_embeddings.shape == (15, 768) + + assert out.shape == (5, 768) + + # item_indices: [0, 2, 3, 7, 8, 11, 12, 13] + # - ---- ---- -- ------ + # span_offsets: [0, 1, 3, 5, 6 ] + + assert torch.allclose(out[0], word_embeddings[0]) + assert torch.allclose(out[1], word_embeddings[2:4].mean(0)) + assert torch.allclose(out[2], word_embeddings[7:9].mean(0)) + assert torch.allclose(out[3], word_embeddings[11]) + assert torch.allclose(out[4], word_embeddings[12:14].mean(0)) diff --git a/tests/pipelines/trainable/test_span_qualifier.py b/tests/pipelines/trainable/test_span_qualifier.py index 66a75abc65..19be040387 100644 --- a/tests/pipelines/trainable/test_span_qualifier.py +++ b/tests/pipelines/trainable/test_span_qualifier.py @@ -49,7 +49,10 @@ def gold(): @pytest.mark.parametrize("with_constraints_and_not_none", [True, False]) -def test_span_qualifier(gold, with_constraints_and_not_none, tmp_path): +@pytest.mark.parametrize("word_pooling_mode", ["mean", False]) +def test_span_qualifier( + gold, with_constraints_and_not_none, word_pooling_mode, tmp_path +): import torch nlp = edsnlp.blank("eds") @@ -60,6 +63,7 @@ def test_span_qualifier(gold, with_constraints_and_not_none, tmp_path): model="prajjwal1/bert-tiny", window=128, stride=96, + word_pooling_mode=word_pooling_mode, ), ) nlp.add_pipe( @@ -69,6 +73,8 @@ def test_span_qualifier(gold, with_constraints_and_not_none, tmp_path): "embedding": { "@factory": "eds.span_pooler", "embedding": nlp.get_pipe("transformer"), + "norm": "layernorm", + "activation": "relu", }, "span_getter": ["ents", "sc"], "qualifiers": {"_.test_negated": True, "_.event_type": ("event",)} diff --git a/tests/pipelines/trainable/test_transformer.py b/tests/pipelines/trainable/test_transformer.py index f2802c3342..d9d9fd6b9d 100644 --- a/tests/pipelines/trainable/test_transformer.py +++ b/tests/pipelines/trainable/test_transformer.py @@ -1,8 +1,11 @@ import pytest +from confit.utils.random import set_seed from pytest import fixture from spacy.tokens import Span import edsnlp +import edsnlp.pipes as eds +from edsnlp.data.converters import MarkupToDocConverter from edsnlp.utils.collections import batch_compress_dict, decompress_dict if not Span.has_extension("label"): @@ -89,4 +92,39 @@ def test_span_getter(gold): batch = trf.collate(batch) batch = trf.batch_to_device(batch, device=trf.device) res = trf(batch) - assert res["embeddings"].shape == (2, 5, 128) + assert res["embeddings"].shape == (9, 128) + + +def test_transformer_pooling(): + nlp = edsnlp.blank("eds") + converter = MarkupToDocConverter(tokenizer=nlp.tokenizer) + doc1 = converter("These are small sentencesstuff.") + doc2 = converter("A tiny one.") + + def run_trf(word_pooling_mode): + set_seed(42) + trf = eds.transformer( + model="prajjwal1/bert-tiny", + window=128, + stride=96, + word_pooling_mode=word_pooling_mode, + ) + prep1 = trf.preprocess(doc1) + prep2 = trf.preprocess(doc2) + assert prep1["input_ids"] == [ + [2122, 2024, 2235, 11746, 3367, 16093, 2546, 1012] + ] + assert prep2["input_ids"] == [[1037, 4714, 2028, 1012]] + batch = decompress_dict(list(batch_compress_dict([prep1, prep2]))) + batch = trf.collate(batch) + return trf(batch) + + 
diff --git a/tests/pipelines/trainable/test_transformer.py b/tests/pipelines/trainable/test_transformer.py
index f2802c3342..d9d9fd6b9d 100644
--- a/tests/pipelines/trainable/test_transformer.py
+++ b/tests/pipelines/trainable/test_transformer.py
@@ -1,8 +1,11 @@
 import pytest
+from confit.utils.random import set_seed
 from pytest import fixture
 from spacy.tokens import Span
 
 import edsnlp
+import edsnlp.pipes as eds
+from edsnlp.data.converters import MarkupToDocConverter
 from edsnlp.utils.collections import batch_compress_dict, decompress_dict
 
 if not Span.has_extension("label"):
@@ -89,4 +92,39 @@ def test_span_getter(gold):
     batch = trf.collate(batch)
     batch = trf.batch_to_device(batch, device=trf.device)
     res = trf(batch)
-    assert res["embeddings"].shape == (2, 5, 128)
+    assert res["embeddings"].shape == (9, 128)
+
+
+def test_transformer_pooling():
+    nlp = edsnlp.blank("eds")
+    converter = MarkupToDocConverter(tokenizer=nlp.tokenizer)
+    doc1 = converter("These are small sentencesstuff.")
+    doc2 = converter("A tiny one.")
+
+    def run_trf(word_pooling_mode):
+        set_seed(42)
+        trf = eds.transformer(
+            model="prajjwal1/bert-tiny",
+            window=128,
+            stride=96,
+            word_pooling_mode=word_pooling_mode,
+        )
+        prep1 = trf.preprocess(doc1)
+        prep2 = trf.preprocess(doc2)
+        assert prep1["input_ids"] == [
+            [2122, 2024, 2235, 11746, 3367, 16093, 2546, 1012]
+        ]
+        assert prep2["input_ids"] == [[1037, 4714, 2028, 1012]]
+        batch = decompress_dict(list(batch_compress_dict([prep1, prep2])))
+        batch = trf.collate(batch)
+        return trf(batch)
+
+    res_pool = run_trf(word_pooling_mode="mean")
+    assert res_pool["embeddings"].shape == (9, 128)
+
+    res_flat = run_trf(word_pooling_mode=False)
+    assert res_flat["embeddings"].shape == (12, 128)
+
+    # The second sequence is identical in both cases (only one subword per word)
+    # so the last 4 word/subword embeddings should be identical
+    assert res_pool["embeddings"][-4:].allclose(res_flat["embeddings"][-4:])
diff --git a/tests/training/ner_qlf_same_bert_config.yml b/tests/training/ner_qlf_same_bert_config.yml
index 9d9aab8e96..060131ec2f 100644
--- a/tests/training/ner_qlf_same_bert_config.yml
+++ b/tests/training/ner_qlf_same_bert_config.yml
@@ -36,6 +36,7 @@ nlp:
       embedding:
         '@factory': eds.span_pooler
+        pooling_mode: attention
         embedding:
           # ${ nlp.components.ner.embedding }
           '@factory': eds.text_cnn