From 3bebda69f161a43987c73b9c46445caccb38bd78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Tue, 2 Sep 2025 09:10:30 +0200 Subject: [PATCH] feat: new attention span_pooling mode --- changelog.md | 8 + docs/tutorials/index.md | 6 +- edsnlp/core/torch_component.py | 40 ++- .../embeddings/span_pooler/span_pooler.py | 240 ++++++++++----- .../trainable/embeddings/text_cnn/text_cnn.py | 6 +- .../embeddings/transformer/transformer.py | 105 +++++-- edsnlp/training/trainer.py | 4 +- edsnlp/tune.py | 2 +- pyproject.toml | 2 +- tests/pipelines/trainable/dummy_embeddings.py | 123 ++++++++ tests/pipelines/trainable/test_span_pooler.py | 277 ++++++++++++++++++ .../trainable/test_span_qualifier.py | 8 +- tests/pipelines/trainable/test_transformer.py | 40 ++- tests/training/ner_qlf_same_bert_config.yml | 1 + 14 files changed, 742 insertions(+), 120 deletions(-) create mode 100644 tests/pipelines/trainable/dummy_embeddings.py create mode 100644 tests/pipelines/trainable/test_span_pooler.py diff --git a/changelog.md b/changelog.md index 92b294b1ae..8650169a6e 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,12 @@ # Changelog +## Unreleased + +### Added + +- New `attention` pooling mode in `eds.span_pooler` +- New `word_pooling_mode=False` in `eds.transformer` to allow returning the worpiece embeddings directly, instead of the mean-pooled word embeddings. At the moment, this only works with `eds.span_pooler` which can pool over wordpieces or words seamlessly. + ## v0.18.0 (2025-09-02) 📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 support in the next major release (v0.19.0), in October 2025. Please upgrade to Python 3.10 or later. @@ -13,6 +20,7 @@ - New `eds.explode` pipe that splits one document into multiple documents, one per span yielded by its `span_getter` parameter, each new document containing exactly that single span. - New `Training a span classifier` tutorial, and reorganized deep-learning docs - `ScheduledOptimizer` now warns when a parameter selector does not match any parameter. +- New `attention` pooling mode in `eds.span_pooler` ### Fixed diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index 9f563e8fce..9d20e11042 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -4,6 +4,7 @@ We provide step-by-step guides to get you started. We cover the following use-ca ### Base tutorials + === card {: href=/tutorials/spacy101 } @@ -85,6 +86,8 @@ We provide step-by-step guides to get you started. We cover the following use-ca --- Quickly visualize the results of your pipeline as annotations or tables. + + ### Deep learning tutorials We also provide tutorials on how to train deep-learning models with EDS-NLP. These tutorials cover the training API, hyperparameter tuning, and more. @@ -123,8 +126,5 @@ We also provide tutorials on how to train deep-learning models with EDS-NLP. The --- Learn how to tune hyperparameters of a model with `edsnlp.tune`. - - - diff --git a/edsnlp/core/torch_component.py b/edsnlp/core/torch_component.py index e78b4fc6b6..0d50538a36 100644 --- a/edsnlp/core/torch_component.py +++ b/edsnlp/core/torch_component.py @@ -339,7 +339,14 @@ def compute_training_metrics( This is useful to compute averages when doing multi-gpu training or mini-batch accumulation since full denominators are not known during the forward pass. 
""" - return batch_output + return ( + { + **batch_output, + "loss": batch_output["loss"] / count, + } + if "loss" in batch_output + else batch_output + ) def module_forward(self, *args, **kwargs): # pragma: no cover """ @@ -348,6 +355,31 @@ def module_forward(self, *args, **kwargs): # pragma: no cover """ return torch.nn.Module.__call__(self, *args, **kwargs) + def preprocess_batch(self, docs: Sequence[Doc], supervision=False, **kwargs): + """ + Convenience method to preprocess a batch of documents. + Features corresponding to the same path are grouped together in a list, + under the same key. + + Parameters + ---------- + docs: Sequence[Doc] + Batch of documents + supervision: bool + Whether to extract supervision features or not + + Returns + ------- + Dict[str, Sequence[Any]] + The batch of features + """ + batch = [ + (self.preprocess_supervised(d) if supervision else self.preprocess(d)) + for d in docs + ] + batch = decompress_dict(list(batch_compress_dict(batch))) + return batch + def prepare_batch( self, docs: Sequence[Doc], @@ -372,11 +404,7 @@ def prepare_batch( ------- Dict[str, Sequence[Any]] """ - batch = [ - (self.preprocess_supervised(doc) if supervision else self.preprocess(doc)) - for doc in docs - ] - batch = decompress_dict(list(batch_compress_dict(batch))) + batch = self.preprocess_batch(docs, supervision=supervision) batch = self.collate(batch) batch = self.batch_to_device(batch, device=device) return batch diff --git a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py index 5e58cd9bb6..3032c574f8 100644 --- a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py +++ b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py @@ -22,21 +22,34 @@ "SpanPoolerBatchInput", { "embedding": BatchInput, - "begins": ft.FoldedTensor, - "ends": ft.FoldedTensor, - "sequence_idx": torch.Tensor, - "stats": TypedDict("SpanPoolerBatchStats", {"spans": int}), + "span_begins": ft.FoldedTensor, + "span_ends": ft.FoldedTensor, + "span_contexts": ft.FoldedTensor, + "item_indices": torch.LongTensor, + "span_offsets": torch.LongTensor, + "span_indices": torch.LongTensor, + "stats": Dict[str, int], }, ) """ -embeds: torch.FloatTensor - Token embeddings to predict the tags from -begins: torch.LongTensor +Attributes +---------- +embedding: BatchInput + The input batch for the word embedding component +span_begins: ft.FoldedTensor Begin offsets of the spans -ends: torch.LongTensor +span_ends: ft.FoldedTensor End offsets of the spans -sequence_idx: torch.LongTensor - Sequence (cf Embedding spans) index of the spans +span_contexts: ft.FoldedTensor + Sequence/context index of the spans +item_indices: torch.LongTensor + Indices of the span's tokens in the tokens embedding output +span_offsets: torch.LongTensor + Offsets of the spans in the flattened span tokens +span_indices: torch.LongTensor + Span index of each token in the flattened span tokens +stats: Dict[str, int] + Statistics about the batch, e.g. 
number of spans """ SpanPoolerBatchOutput = TypedDict( @@ -45,6 +58,12 @@ "embeddings": ft.FoldedTensor, }, ) +""" +Attributes +---------- +embeddings: ft.FoldedTensor + The output span embeddings, with foldable dimensions ("sample", "span") +""" class SpanPooler(SpanEmbeddingComponent, BaseComponent): @@ -61,8 +80,14 @@ class SpanPooler(SpanEmbeddingComponent, BaseComponent): Name of the component embedding : WordEmbeddingComponent The word embedding component - pooling_mode: Literal["max", "sum", "mean"] - How word embeddings are aggregated into a single embedding per span. + pooling_mode: Literal["max", "sum", "mean", "attention"] + How word embeddings are aggregated into a single embedding per span: + + - "max": max pooling + - "sum": sum pooling + - "mean": mean pooling + - "attention": attention pooling, where attention scores are computed using a + linear layer followed by a softmax over the tokens in the span. hidden_size : Optional[int] The size of the hidden layer. If None, no projection is done and the output of the span pooler is used directly. @@ -74,7 +99,9 @@ def __init__( name: str = "span_pooler", *, embedding: WordEmbeddingComponent, - pooling_mode: Literal["max", "sum", "mean"] = "mean", + pooling_mode: Literal["max", "sum", "mean", "attention"] = "mean", + activation: Optional[Literal["relu", "gelu", "silu"]] = None, + norm: Optional[Literal["layernorm", "batchnorm"]] = None, hidden_size: Optional[int] = None, span_getter: Any = None, ): @@ -99,11 +126,35 @@ def __init__( self.pooling_mode = pooling_mode self.span_getter = span_getter self.embedding = embedding - self.projector = ( - torch.nn.Linear(self.embedding.output_size, hidden_size) - if hidden_size is not None - else torch.nn.Identity() - ) + self.activation = activation + self.projector = torch.nn.Sequential() + if hidden_size is not None: + self.projector.append( + torch.nn.Linear(self.embedding.output_size, hidden_size) + ) + if activation is not None: + self.projector.append( + { + "relu": torch.nn.ReLU, + "gelu": torch.nn.GELU, + "silu": torch.nn.SiLU, + }[activation]() + ) + if norm is not None: + self.projector.append( + { + "layernorm": torch.nn.LayerNorm, + "batchnorm": torch.nn.BatchNorm1d, + }[norm]( + hidden_size + if hidden_size is not None + else self.embedding.output_size + ) + ) + if self.pooling_mode in {"attention"}: + self.attention_scorer = torch.nn.Linear( + self.embedding.output_size, 1, bias=False + ) def feed_forward(self, span_embeds: torch.Tensor) -> torch.Tensor: return self.projector(span_embeds) @@ -112,18 +163,23 @@ def preprocess( self, doc: Doc, *, - spans: Optional[Sequence[Span]] = None, + spans: Optional[Sequence[Span]], contexts: Optional[Sequence[Span]] = None, pre_aligned: bool = False, **kwargs, ) -> Dict[str, Any]: - contexts = contexts if contexts is not None else [doc[:]] + if contexts is None: + contexts = [doc[:]] * len(spans) + pre_aligned = True - sequence_idx = [] + context_indices = [] begins = [] ends = [] - contexts_to_idx = {span: i for i, span in enumerate(contexts)} + contexts_to_idx = {} + for ctx in contexts: + if ctx not in contexts_to_idx: + contexts_to_idx[ctx] = len(contexts_to_idx) assert not pre_aligned or len(spans) == len(contexts), ( "When `pre_aligned` is True, the number of spans and contexts must be the " "same." 
@@ -140,52 +196,96 @@ def preprocess( f"span: {[s.text for s in ctx]}" ) start = ctx[0].start - sequence_idx.append(contexts_to_idx[ctx[0]]) + context_indices.append(contexts_to_idx[ctx[0]]) begins.append(span.start - start) ends.append(span.end - start) return { "begins": begins, "ends": ends, - "sequence_idx": sequence_idx, + "span_to_ctx_idx": context_indices, "num_sequences": len(contexts), - "embedding": self.embedding.preprocess(doc, contexts=contexts, **kwargs), + "embedding": self.embedding.preprocess( + doc, contexts=list(contexts_to_idx), **kwargs + ), "stats": {"spans": len(begins)}, } def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanPoolerBatchInput: - sequence_idx = [] - offset = 0 - for indices, seq_length in zip(batch["sequence_idx"], batch["num_sequences"]): - sequence_idx.extend([offset + idx for idx in indices]) - offset += seq_length + embedding_batch = self.embedding.collate(batch["embedding"]) + embed_structure = embedding_batch["out_structure"] + ft_kw = dict( + data_dims=("span",), + full_names=("sample", "span"), + dtype=torch.long, + ) + begins = ft.as_folded_tensor(batch["begins"], **ft_kw) + ends = ft.as_folded_tensor(batch["ends"], **ft_kw) + span_to_ctx_idx = [] + total_num_ctx = 0 + for i, (ctx_indices, num_ctx) in enumerate( + zip(batch["span_to_ctx_idx"], embed_structure["context"]) + ): + span_to_ctx_idx.append([idx + total_num_ctx for idx in ctx_indices]) + total_num_ctx += num_ctx + flat_span_to_ctx_idx = ft.as_folded_tensor(span_to_ctx_idx, **ft_kw) + item_indices, span_offsets, span_indices = embed_structure.make_indices_ranges( + begins=(flat_span_to_ctx_idx, begins), + ends=(flat_span_to_ctx_idx, ends), + indice_dims=("context", "word"), + ) collated: SpanPoolerBatchInput = { - "embedding": self.embedding.collate(batch["embedding"]), - "begins": ft.as_folded_tensor( - batch["begins"], - data_dims=("span",), - full_names=("sample", "span"), - dtype=torch.long, - ), - "ends": ft.as_folded_tensor( - batch["ends"], - data_dims=("span",), - full_names=("sample", "span"), - dtype=torch.long, - ), - "sequence_idx": torch.as_tensor(sequence_idx), + "embedding": embedding_batch, + "span_begins": begins, + "span_ends": ends, + "span_contexts": flat_span_to_ctx_idx, + "item_indices": item_indices, + "span_offsets": begins.with_data(span_offsets), + "span_indices": span_indices, "stats": {"spans": sum(batch["stats"]["spans"])}, } return collated + def _pool_spans(self, embeds, span_indices, span_offsets, item_indices=None): + dev = span_offsets.device + dim = embeds.size(-1) + embeds = embeds.as_tensor().view(-1, dim) + N = span_offsets.numel() # number of spans + + if self.pooling_mode == "attention": + if item_indices is not None: + embeds = embeds[item_indices] + weights = self.attention_scorer(embeds) + # compute max for softmax stability + dtype = weights.dtype + max_weights = torch.full((N, 1), float("-inf"), device=dev, dtype=dtype) + max_weights.index_reduce_(0, span_indices, weights, reduce="amax") + # softmax numerator + exp_scores = torch.exp(weights - max_weights[span_indices]) + # softmax denominator + denom = torch.zeros((N, 1), device=dev, dtype=exp_scores.dtype) + denom.index_add_(0, span_indices, exp_scores) + # softmax output = embeds * weight num / weight denom + weighted_embeds = embeds * exp_scores / denom[span_indices] + span_embeds = torch.zeros((N, dim), device=dev, dtype=embeds.dtype) + span_embeds.index_add_(0, span_indices, weighted_embeds) + span_embeds = span_offsets.with_data(span_embeds) + else: + span_embeds = 
torch.nn.functional.embedding_bag( # type: ignore + input=torch.arange(len(embeds), device=dev) + if item_indices is None + else item_indices, + weight=embeds, + offsets=span_offsets, + mode=self.pooling_mode, + ) + span_embeds = self.feed_forward(span_embeds) + return span_embeds + # noinspection SpellCheckingInspection def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput: """ - Apply the span classifier module to the document embeddings and given spans to: - - compute the loss - - and/or predict the labels of spans - If labels are predicted, they are assigned to the `additional_outputs` - dictionary. + Forward pass of the component, returns span embeddings. Parameters ---------- @@ -194,38 +294,26 @@ def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput: Returns ------- - BatchOutput + SpanPoolerBatchOutput """ - device = next(self.parameters()).device - if len(batch["begins"]) == 0: - span_embeds = torch.empty(0, self.output_size, device=device) + if len(batch["span_begins"]) == 0: return { - "embeddings": batch["begins"].with_data(span_embeds), + "embeddings": batch["span_begins"].with_data( + torch.empty( + 0, + self.output_size, + device=batch["item_indices"].device, + ) + ), } embeds = self.embedding(batch["embedding"])["embeddings"] - _, n_words, dim = embeds.shape - device = embeds.device - - flat_begins = n_words * batch["sequence_idx"] + batch["begins"].as_tensor() - flat_ends = n_words * batch["sequence_idx"] + batch["ends"].as_tensor() - flat_embeds = embeds.view(-1, dim) - flat_indices = torch.cat( - [ - torch.arange(b, e, device=device) - for b, e in zip(flat_begins.cpu().tolist(), flat_ends.cpu().tolist()) - ] - ).to(device) - offsets = (flat_ends - flat_begins).cumsum(0).roll(1) - offsets[0] = 0 - span_embeds = torch.nn.functional.embedding_bag( # type: ignore - input=flat_indices, - weight=flat_embeds, - offsets=offsets, - mode=self.pooling_mode, + span_embeds = self._pool_spans( + embeds, + span_indices=batch["span_indices"], + span_offsets=batch["span_offsets"], + item_indices=batch["item_indices"], ) - span_embeds = self.feed_forward(span_embeds) - return { - "embeddings": batch["begins"].with_data(span_embeds), + "embeddings": batch["span_begins"].with_data(span_embeds), } diff --git a/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py b/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py index c1de49a562..92fff74928 100644 --- a/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py +++ b/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence +from typing import Any, Dict, Optional, Sequence import torch from typing_extensions import Literal, TypedDict @@ -87,6 +87,10 @@ def __init__( normalize=normalize, ) + def collate(self, batch: Dict[str, Any]) -> BatchInput: + emb = self.embedding.collate(batch["embedding"]) + return {"embedding": emb, "out_structure": emb["out_structure"]} + def forward(self, batch: BatchInput) -> WordEmbeddingBatchOutput: """ Encode embeddings with a 1d convolutional network diff --git a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py index 4d830858e8..74205369a8 100644 --- a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py +++ b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py @@ -34,6 +34,8 @@ }, ) """ +Attributes +---------- input_ids: FoldedTensor Tokenized input (prompt + text) to embed word_indices: torch.LongTensor @@ -51,6 +53,8 @@ }, ) 
""" +Attributes +---------- embeddings: FoldedTensor The embeddings of the words """ @@ -101,9 +105,23 @@ class Transformer(WordEmbeddingComponent[TransformerBatchInput]): stride=96, ), ) + + doc1 = nlp.make_doc("My name is Michael.") + doc2 = nlp.make_doc("And I am the best boss in the world.") + prep = nlp.pipes.transformer.preprocess_batch([doc1, doc2]) + batch = nlp.pipes.transformer.collate(prep) + res = nlp.pipes.transformer(batch) + + # Embeddings are flattened by default + print(res["embeddings"].shape) + # Out: torch.Size([15, 128]) + + # But they can be refolded to materialize the sample dimension + print(res["embeddings"].refold("sample", "word").shape) + # Out: torch.Size([2, 10, 128]) ``` - You can then compose this embedding with a task specific component such as + You can compose this embedding with a task specific component such as `eds.ner_crf`. Parameters @@ -131,9 +149,25 @@ class Transformer(WordEmbeddingComponent[TransformerBatchInput]): If "auto", the component will try to estimate the maximum number of tokens that can be processed by the model on the current device at a given time. - span_getter: Optional[SpanGetterArg] - Which spans of the document should be embedded. Defaults to the full document - if None. + new_tokens: Optional[List[Tuple[str, str]]] + A list of (pattern, replacement) tuples to add to the tokenizer. The pattern + should be a valid regular expression. The replacement should be a string. + + This can be used to add new tokens to the tokenizer that are not present in the + original vocabulary. For example, if you want to add a new token for new lines + you can use the following: + + ```python + new_tokens = [("\\n", "⏎")] + ``` + quantization: Optional[BitsAndBytesConfig] + Quantization configuration to use for the model. If None, no quantization + will be applied. This requires the `bitsandbytes` library to be installed. + word_pooling_mode: Literal["mean", False] + If "mean", the embeddings of the wordpieces corresponding to each word will be + averaged to produce a single embedding per word. If False, the embeddings of the + wordpieces will be returned as a FoldedTensor with an additional "token" + dimension. 
(default: "mean") """ def __init__( @@ -149,6 +183,7 @@ def __init__( span_getter: Optional[SpanGetterArg] = None, new_tokens: Optional[List[Tuple[str, str]]] = [], quantization: Optional[BitsAndBytesConfig] = None, + word_pooling_mode: Literal["mean", False] = "mean", **kwargs, ): super().__init__(nlp, name) @@ -167,6 +202,7 @@ def __init__( kwargs["quantization_config"] = quantization self.transformer = AutoModel.from_pretrained(model, **kwargs) + self.word_pooling_mode = word_pooling_mode try: self.tokenizer = AutoTokenizer.from_pretrained(model) except (HTTPException, ConnectionError): # pragma: no cover @@ -353,6 +389,7 @@ def collate(self, batch): empty_word_indices = [] overlap = self.window - stride word_offset = 0 + word_sizes = [] all_word_wp_offset = 0 for ( sample_text_input_ids, @@ -422,23 +459,38 @@ def collate(self, batch): ] ] word_indices.extend(word_wp_indices) + word_sizes.append(length) word_wp_offset += length word_offset += 1 all_word_wp_offset += word_wp_offset + word_offsets = ft.as_folded_tensor( + word_offsets, + data_dims=("word",), + full_names=("sample", "context", "word"), + dtype=torch.long, + ) + out_structure = ( + ft.FoldedTensorLayout( + [ + *word_offsets.lengths, + word_sizes, + ], + full_names=("sample", "context", "word", "token"), + data_dims=("token",), + ) + if not self.word_pooling_mode + else word_offsets.lengths + ) return { + "out_structure": out_structure, "input_ids": ft.as_folded_tensor( input_ids, data_dims=("context", "subword"), full_names=("context", "subword"), dtype=torch.long, ), - "word_offsets": ft.as_folded_tensor( - word_offsets, - data_dims=("word",), - full_names=("sample", "context", "word"), - dtype=torch.long, - ), + "word_offsets": word_offsets, "word_indices": torch.as_tensor(word_indices, dtype=torch.long), "empty_word_indices": torch.as_tensor(empty_word_indices, dtype=torch.long), "stats": { @@ -480,7 +532,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput: max_windows = max(1, max_tokens // input_ids.size(1)) total_windows = input_ids.size(0) try: - wordpiece_embeddings = [ + wp_embs = [ self.transformer.base_model( **{ k: None if v is None else v[offset : offset + max_windows] @@ -490,11 +542,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput: for offset in range(0, total_windows, max_windows) ] - wordpiece_embeddings = ( - torch.cat(wordpiece_embeddings, dim=0) - if len(wordpiece_embeddings) > 1 - else wordpiece_embeddings[0] - ) + wp_embs = torch.cat(wp_embs, dim=0) if len(wp_embs) > 1 else wp_embs[0] if auto_batch_size: # pragma: no cover batch_mem = torch.cuda.max_memory_allocated(device) @@ -520,26 +568,29 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput: # mask = batch["mask"].clone() # word_embeddings = torch.zeros( - # (mask.size(0), mask.size(1), wordpiece_embeddings.size(2)), + # (mask.size(0), mask.size(1), wp_embs.size(2)), # dtype=torch.float, # device=device, # ) # embeddings_plus_empty = torch.cat( # [ - # wordpiece_embeddings.view(-1, wordpiece_embeddings.size(2)), + # wp_embs.view(-1, wp_embs.size(2)), # self.empty_word_embedding, # ], # dim=0, # ) - word_embeddings = torch.nn.functional.embedding_bag( - input=batch["word_indices"], - weight=wordpiece_embeddings.reshape(-1, wordpiece_embeddings.size(2)), - offsets=batch["word_offsets"], - ) - word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding - return { - "embeddings": word_embeddings.refold("context", "word"), - } + if self.word_pooling_mode == 
"mean": + word_embeddings = torch.nn.functional.embedding_bag( + input=batch["word_indices"], + weight=wp_embs.reshape(-1, wp_embs.size(2)), + offsets=batch["word_offsets"], + ) + word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding + return {"embeddings": word_embeddings} + else: + wp_embs = wp_embs.reshape(-1, self.output_size)[batch["word_indices"]] + wp_embs = ft.as_folded_tensor(wp_embs, lengths=batch["out_structure"]) + return {"embeddings": wp_embs} @staticmethod def align_words_with_trf_tokens(doc, trf_char_indices): diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py index 634e2296ab..c480e12b28 100644 --- a/edsnlp/training/trainer.py +++ b/edsnlp/training/trainer.py @@ -849,9 +849,7 @@ def train( for idx, (batch, batch_pipe_names) in enumerate( zip(batches, batches_pipe_names) ): - cache_ctx = ( - nlp.cache() if len(batch_pipe_names) > 1 else nullcontext() - ) + cache_ctx = nlp.cache() no_sync_ctx = ( accelerator.no_sync(trained_pipes) if idx < len(batches) - 1 diff --git a/edsnlp/tune.py b/edsnlp/tune.py index c7c46d6810..0f56e0194c 100644 --- a/edsnlp/tune.py +++ b/edsnlp/tune.py @@ -595,7 +595,7 @@ def compute_remaining_n_trials_possible( remaining_gpu_time, compute_time_per_trial(study, ema=True) ) return n_trials - except ValueError: + except ValueError: # pragma: no cover return 0 diff --git a/pyproject.toml b/pyproject.toml index 1ccd9df4bf..465d996e94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ docs-no-ml = [ ml = [ "rich-logger>=0.3.1", "torch>=1.13.0; python_version>='3.9'", - "foldedtensor>=0.4.0", + "foldedtensor @ git+https://github.com/aphp/foldedtensor.git@indice-mapping#egg=foldedtensor", "safetensors>=0.3.0; python_version>='3.8'", "safetensors>=0.3.0,<0.5.0; python_version<'3.8'", "transformers>=4.0.0", diff --git a/tests/pipelines/trainable/dummy_embeddings.py b/tests/pipelines/trainable/dummy_embeddings.py new file mode 100644 index 0000000000..c1cd3d3fd1 --- /dev/null +++ b/tests/pipelines/trainable/dummy_embeddings.py @@ -0,0 +1,123 @@ +from typing import List, Optional + +import pytest + +pytest.importorskip("torch") + +import foldedtensor as ft +import torch +from typing_extensions import Literal + +from edsnlp import Pipeline +from edsnlp.pipes.trainable.embeddings.typing import WordEmbeddingComponent + + +class DummyEmbeddings(WordEmbeddingComponent[dict]): + """ + For each word, embedding = (word idx in sent) * [1, 1, ..., 1] (size = dim) + """ + + def __init__( + self, + nlp: Optional[Pipeline] = None, + name: str = "fixed_embeddings", + word_pooling_mode: Literal["mean", False] = "mean", + *, + dim: int, + ): + super().__init__(nlp, name) + self.output_size = int(dim) + self.word_pooling_mode = word_pooling_mode + + def preprocess(self, doc, *, contexts=None, prompts=()): + if contexts is None: + contexts = [doc[:]] + + inputs: List[List[List[int]]] = [] + total = 0 + + for ctx in contexts: + words = [] + for word in ctx: + subwords = [] + for subword in word.text[::4]: + subwords.append(total) + total += 1 + words.append(subwords) + inputs.append(words) + + return { + "inputs": inputs, # List[Context][Word] -> int + } + + def collate(self, batch): + # Flatten indices and keep per-(sample,context) lengths to refold later + inputs = ft.as_folded_tensor( + batch["inputs"], + data_dims=("sample", "token"), + full_names=("sample", "context", "word", "token"), + dtype=torch.long, + ) + item_indices = span_offsets = span_indices = None + if self.word_pooling_mode == "mean": + samples = 
torch.arange(max(inputs.lengths["sample"])) + words = torch.arange(max(inputs.lengths["word"])) + n_words = len(words) + n_samples = len(samples) + words = words[None, :].expand(n_samples, -1) + samples = samples[:, None].expand(-1, n_words) + words = words.masked_fill( + ~inputs.refold("sample", "word", "token").mask.any(-1), 0 + ) + item_indices, span_offsets, span_indices = ( + inputs.lengths.make_indices_ranges( + begins=(samples, words), + ends=(samples, words + 1), + indice_dims=( + "sample", + "word", + ), + return_tensors="pt", + ) + ) + span_offsets = ft.as_folded_tensor( + span_offsets, + data_dims=( + "sample", + "word", + ), + full_names=("sample", "context", "word"), + lengths=list(inputs.lengths)[0:-1], + ) + + return { + "out_structure": span_offsets.lengths + if self.word_pooling_mode == "mean" + else inputs.lengths, + "inputs": inputs, + "item_indices": item_indices, + "span_offsets": span_offsets, + "span_indices": span_indices, + } + + def forward(self, batch): + embeddings = ( + batch["inputs"] + .unsqueeze(-1) + .expand(-1, -1, self.output_size) + .to(torch.float32) + ) + print("shape before pool", embeddings.shape) + if self.word_pooling_mode == "mean": + embeddings = torch.nn.functional.embedding_bag( + embeddings.view(-1, self.output_size), + batch["item_indices"], + offsets=batch["span_offsets"].view(-1), + mode="max", + ) + embeddings = batch["span_offsets"].with_data( + embeddings.view(*batch["span_offsets"].shape, self.output_size) + ) + return { + "embeddings": embeddings, + } diff --git a/tests/pipelines/trainable/test_span_pooler.py b/tests/pipelines/trainable/test_span_pooler.py new file mode 100644 index 0000000000..932a974342 --- /dev/null +++ b/tests/pipelines/trainable/test_span_pooler.py @@ -0,0 +1,277 @@ +import confit.utils.random +import pytest +from dummy_embeddings import DummyEmbeddings + +import edsnlp +import edsnlp.pipes as eds +from edsnlp.data.converters import MarkupToDocConverter +from edsnlp.pipes.trainable.embeddings.span_pooler.span_pooler import SpanPooler +from edsnlp.utils.collections import batch_compress_dict, decompress_dict + +pytest.importorskip("torch.nn") + +import torch + + +@pytest.mark.parametrize( + "word_pooling_mode,shape", + [ + ("mean", (2, 5, 2)), + (False, (2, 6, 2)), + ], +) +def test_dummy_embeddings(word_pooling_mode, shape): + confit.utils.random.set_seed(42) + converter = MarkupToDocConverter() + doc1 = converter("This is a sentence.") + doc2 = converter("A shorter one.") + nlp = edsnlp.blank("eds") + nlp.add_pipe( + DummyEmbeddings(dim=2, word_pooling_mode=word_pooling_mode), name="embeddings" + ) + embedder: DummyEmbeddings = nlp.pipes.embeddings + + prep1 = embedder.preprocess(doc1) + prep2 = embedder.preprocess(doc2) + pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2]))) + batch = embedder.collate(pivoted_prep) + out = embedder.forward(batch)["embeddings"] + + assert out.shape == shape + + +@pytest.mark.parametrize("span_pooling_mode", ["max", "mean", "attention"]) +def test_span_pooler_on_words(span_pooling_mode): + confit.utils.random.set_seed(42) + converter = MarkupToDocConverter() + doc1 = converter("[This](ent) is [a sentence](ent). 
This is [small one](ent).") + doc2 = converter("An [even shorter one](ent) !") + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.span_pooler( + embedding=DummyEmbeddings(dim=2), + pooling_mode=span_pooling_mode, + ) + ) + pooler: SpanPooler = nlp.pipes.span_pooler + + prep1 = pooler.preprocess(doc1, spans=doc1.ents) + prep2 = pooler.preprocess(doc2, spans=doc2.ents) + pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2]))) + batch = pooler.collate(pivoted_prep) + out = pooler.forward(batch)["embeddings"] + + assert out.shape == (4, 2) + out = out.refold("sample", "span") + + assert out.shape == (2, 3, 2) + if span_pooling_mode == "attention": + expected = [ + [[0.0000, 0.0000], [3.8102, 3.8102], [9.7554, 9.7554]], + [[3.6865, 3.6865], [0.0000, 0.0000], [0.0000, 0.0000]], + ] + elif span_pooling_mode == "mean": + expected = [ + [[0.0000, 0.0000], [3.0000, 3.0000], [9.5000, 9.5000]], + [[2.6667, 2.6667], [0.0000, 0.0000], [0.0000, 0.0000]], + ] + elif span_pooling_mode == "max": + expected = [ + [[0.0000, 0.0000], [4.0000, 4.0000], [10.0000, 10.0000]], + [[4.0000, 4.0000], [0.0000, 0.0000], [0.0000, 0.0000]], + ] + else: + raise ValueError(f"Unknown pooling mode: {span_pooling_mode}") + assert torch.allclose(out, torch.tensor(expected), atol=1e-4) + + +@pytest.mark.parametrize("span_pooling_mode", ["max", "mean", "attention"]) +def test_span_pooler_on_tokens(span_pooling_mode): + confit.utils.random.set_seed(42) + converter = MarkupToDocConverter() + doc1 = converter("[This](ent) is [a sentence](ent). This is [small one](ent).") + doc2 = converter("An [even shorter one](ent) !") + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.span_pooler( + embedding=DummyEmbeddings(dim=2, word_pooling_mode=False), + pooling_mode=span_pooling_mode, + ) + ) + pooler: SpanPooler = nlp.pipes.span_pooler + + prep1 = pooler.preprocess(doc1, spans=doc1.ents) + prep2 = pooler.preprocess(doc2, spans=doc2.ents) + pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2]))) + batch = pooler.collate(pivoted_prep) + out = pooler.forward(batch)["embeddings"] + + assert out.shape == (4, 2) + out = out.refold("sample", "span") + + assert out.shape == (2, 3, 2) + if span_pooling_mode == "attention": + expected = [ + [[0.0000, 0.0000], [3.6265, 3.6265], [9.6265, 9.6265]], + [[3.5655, 3.5655], [0.0000, 0.0000], [0.0000, 0.0000]], + ] + elif span_pooling_mode == "mean": + expected = [ + [[0.0000, 0.0000], [3.0000, 3.0000], [9.0000, 9.0000]], + [[2.5000, 2.5000], [0.0000, 0.0000], [0.0000, 0.0000]], + ] + elif span_pooling_mode == "max": + expected = [ + [[0.0000, 0.0000], [4.0000, 4.0000], [10.0000, 10.0000]], + [[4.0000, 4.0000], [0.0000, 0.0000], [0.0000, 0.0000]], + ] + else: + raise ValueError(f"Unknown pooling mode: {span_pooling_mode}") + assert torch.allclose(out, torch.tensor(expected), atol=1e-4) + + +def test_span_pooler_on_flat_hf_tokens(): + confit.utils.random.set_seed(42) + converter = MarkupToDocConverter() + doc1 = converter("[This](ent) is [a sentence](ent). 
This is [small one](ent).") + doc2 = converter("An [even](ent) [shorter one](ent) !") + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.span_pooler( + embedding=eds.transformer( + model="almanach/camembert-base", + word_pooling_mode=False, + ), + pooling_mode="mean", + ) + ) + pooler: SpanPooler = nlp.pipes.span_pooler + + prep1 = pooler.preprocess(doc1, spans=doc1.ents) + prep2 = pooler.preprocess(doc2, spans=doc2.ents) + pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2]))) + print( + nlp.pipes.span_pooler.embedding.tokenizer.convert_ids_to_tokens( + prep2["embedding"]["input_ids"][0] + ) + ) + # fmt: off + assert prep1["embedding"]["input_ids"] == [ + [ + 17526, # ▁This: 0 -> span 0 + 2856, # ▁is: 1 + 33, # ▁a: 2 -> span 1 + 22625, # ▁sentence: 3 -> span 1 + 9, # .: 4 + 17526, # ▁This: 5 + 2856, # ▁is: 6 + 52, # ▁s: 7 -> span 2 + 215, # m: 8 -> span 2 + 3645, # all: 9 -> span 2 + 91, # ▁on: 10 -> span 2 + 35, # e: 11 -> span 2 + 9, # .: 12 + ], + ] + # '▁An', '▁', 'even', '▁short', 'er', '▁on', 'e', '▁!' + assert prep2["embedding"]["input_ids"] == [ + [ + 2764, # ▁An: 13 + 21, # ▁: 14 + 15999, # even: 15 -> span 3 + 9161, # short: 16 -> span 4 + 108, # er: 17 -> span 4 + 91, # ▁on: 18 -> span 4 + 35, # e: 19 -> span 4 + 83, # ▁!: 20 + ] + ] + # fmt: on + batch = pooler.collate(pivoted_prep) + out = pooler.forward(batch)["embeddings"] + + word_embeddings = pooler.embedding(batch["embedding"])["embeddings"] + assert word_embeddings.shape == (21, 768) + + assert out.shape == (5, 768) + + # item_indices: [0, 2, 3, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19] + # - ---- --------------- ------ -------------- + # span_offsets: [0, 1, 3, 8, 10] + # span_indices: [0, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4] + + assert torch.allclose(out[0], word_embeddings[0]) + assert torch.allclose(out[1], word_embeddings[2:4].mean(0)) + assert torch.allclose(out[2], word_embeddings[7:12].mean(0)) + assert torch.allclose(out[3], word_embeddings[14:16].mean(0)) + assert torch.allclose(out[4], word_embeddings[16:20].mean(0)) + + +def test_span_pooler_on_pooled_hf_tokens(): + confit.utils.random.set_seed(42) + converter = MarkupToDocConverter() + doc1 = converter("[This](ent) is [a sentence](ent). This is [small one](ent).") + doc2 = converter("An [even](ent) [shorter one](ent) !") + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.span_pooler( + embedding=eds.transformer( + model="almanach/camembert-base", + word_pooling_mode="mean", + ), + pooling_mode="mean", + ) + ) + pooler: SpanPooler = nlp.pipes.span_pooler + + prep1 = pooler.preprocess(doc1, spans=doc1.ents) + prep2 = pooler.preprocess(doc2, spans=doc2.ents) + pivoted_prep = decompress_dict(list(batch_compress_dict([prep1, prep2]))) + print( + nlp.pipes.span_pooler.embedding.tokenizer.convert_ids_to_tokens( + prep2["embedding"]["input_ids"][0] + ) + ) + # fmt: off + assert prep1["embedding"]["input_ids"] == [ + [ + 17526, # ▁This: 0 -> span 0 + 2856, # ▁is: 1 + 33, # ▁a: 2 -> span 1 + 22625, # ▁sentence: 3 -> span 1 + 9, # .: 4 + 17526, # ▁This: 5 + 2856, # ▁is: 6 + 52, 215, 3645, # ▁s m all: 7 -> span 2 + 91, 35, # ▁on e: 8 -> span 2 + 9, # .: 9 + ], + ] + # '▁An', '▁', 'even', '▁short', 'er', '▁on', 'e', '▁!' 
+ assert prep2["embedding"]["input_ids"] == [ + [ + 2764, # ▁An: 10 + 21, 15999, # ▁, even: 11 -> span 3 + 9161, 108, # short er: 12 -> span 4 + 91, 35, # ▁on e: 13 -> span 4 + 83, # ▁!: 14 + ] + ] + # fmt: on + batch = pooler.collate(pivoted_prep) + out = pooler.forward(batch)["embeddings"] + + word_embeddings = pooler.embedding(batch["embedding"])["embeddings"] + assert word_embeddings.shape == (15, 768) + + assert out.shape == (5, 768) + + # item_indices: [0, 2, 3, 7, 8, 11, 12, 13] + # - ---- ---- -- ------ + # span_offsets: [0, 1, 3, 5, 6 ] + + assert torch.allclose(out[0], word_embeddings[0]) + assert torch.allclose(out[1], word_embeddings[2:4].mean(0)) + assert torch.allclose(out[2], word_embeddings[7:9].mean(0)) + assert torch.allclose(out[3], word_embeddings[11]) + assert torch.allclose(out[4], word_embeddings[12:14].mean(0)) diff --git a/tests/pipelines/trainable/test_span_qualifier.py b/tests/pipelines/trainable/test_span_qualifier.py index 66a75abc65..19be040387 100644 --- a/tests/pipelines/trainable/test_span_qualifier.py +++ b/tests/pipelines/trainable/test_span_qualifier.py @@ -49,7 +49,10 @@ def gold(): @pytest.mark.parametrize("with_constraints_and_not_none", [True, False]) -def test_span_qualifier(gold, with_constraints_and_not_none, tmp_path): +@pytest.mark.parametrize("word_pooling_mode", ["mean", False]) +def test_span_qualifier( + gold, with_constraints_and_not_none, word_pooling_mode, tmp_path +): import torch nlp = edsnlp.blank("eds") @@ -60,6 +63,7 @@ def test_span_qualifier(gold, with_constraints_and_not_none, tmp_path): model="prajjwal1/bert-tiny", window=128, stride=96, + word_pooling_mode=word_pooling_mode, ), ) nlp.add_pipe( @@ -69,6 +73,8 @@ def test_span_qualifier(gold, with_constraints_and_not_none, tmp_path): "embedding": { "@factory": "eds.span_pooler", "embedding": nlp.get_pipe("transformer"), + "norm": "layernorm", + "activation": "relu", }, "span_getter": ["ents", "sc"], "qualifiers": {"_.test_negated": True, "_.event_type": ("event",)} diff --git a/tests/pipelines/trainable/test_transformer.py b/tests/pipelines/trainable/test_transformer.py index f2802c3342..d9d9fd6b9d 100644 --- a/tests/pipelines/trainable/test_transformer.py +++ b/tests/pipelines/trainable/test_transformer.py @@ -1,8 +1,11 @@ import pytest +from confit.utils.random import set_seed from pytest import fixture from spacy.tokens import Span import edsnlp +import edsnlp.pipes as eds +from edsnlp.data.converters import MarkupToDocConverter from edsnlp.utils.collections import batch_compress_dict, decompress_dict if not Span.has_extension("label"): @@ -89,4 +92,39 @@ def test_span_getter(gold): batch = trf.collate(batch) batch = trf.batch_to_device(batch, device=trf.device) res = trf(batch) - assert res["embeddings"].shape == (2, 5, 128) + assert res["embeddings"].shape == (9, 128) + + +def test_transformer_pooling(): + nlp = edsnlp.blank("eds") + converter = MarkupToDocConverter(tokenizer=nlp.tokenizer) + doc1 = converter("These are small sentencesstuff.") + doc2 = converter("A tiny one.") + + def run_trf(word_pooling_mode): + set_seed(42) + trf = eds.transformer( + model="prajjwal1/bert-tiny", + window=128, + stride=96, + word_pooling_mode=word_pooling_mode, + ) + prep1 = trf.preprocess(doc1) + prep2 = trf.preprocess(doc2) + assert prep1["input_ids"] == [ + [2122, 2024, 2235, 11746, 3367, 16093, 2546, 1012] + ] + assert prep2["input_ids"] == [[1037, 4714, 2028, 1012]] + batch = decompress_dict(list(batch_compress_dict([prep1, prep2]))) + batch = trf.collate(batch) + return trf(batch) + + 
res_pool = run_trf(word_pooling_mode="mean") + assert res_pool["embeddings"].shape == (9, 128) + + res_flat = run_trf(word_pooling_mode=False) + assert res_flat["embeddings"].shape == (12, 128) + + # The second sequence is identical in both cases (only one subword per word) + # so the last 4 word/subword embeddings should be identical + assert res_pool["embeddings"][-4:].allclose(res_flat["embeddings"][-4:]) diff --git a/tests/training/ner_qlf_same_bert_config.yml b/tests/training/ner_qlf_same_bert_config.yml index 9d9aab8e96..060131ec2f 100644 --- a/tests/training/ner_qlf_same_bert_config.yml +++ b/tests/training/ner_qlf_same_bert_config.yml @@ -36,6 +36,7 @@ nlp: embedding: '@factory': eds.span_pooler + pooling_mode: attention embedding: # ${ nlp.components.ner.embedding } '@factory': eds.text_cnn
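
Usage sketch for the two options this patch introduces, based on the pipes exercised in the new tests: `pooling_mode="attention"` in `eds.span_pooler` and `word_pooling_mode=False` in `eds.transformer`. This is a minimal illustration only; the model name is the small placeholder used in the test suite, not a recommendation.

```python
import edsnlp
import edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.span_pooler(
        # word_pooling_mode=False makes the transformer return wordpiece
        # embeddings directly instead of mean-pooled word embeddings;
        # the span pooler can pool over wordpieces or words seamlessly.
        embedding=eds.transformer(
            model="prajjwal1/bert-tiny",  # placeholder model from the tests
            window=128,
            stride=96,
            word_pooling_mode=False,
        ),
        # new mode: a linear scorer followed by a softmax over each span's tokens
        pooling_mode="attention",
    ),
)
```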