diff --git a/README.md b/README.md index 61e19fc..f4516d2 100644 --- a/README.md +++ b/README.md @@ -60,31 +60,29 @@ Dependencies project](https://universaldependencies.org/). -UDTube can perform up to four morphological tasks simultaneously: +UDTube can perform up to five tasks simultaneously: -- Lemmatization is performed using the `LEMMA` field and [edit - scripts](https://aclanthology.org/P14-2111/). - -- [Universal part-of-speech - tagging](https://universaldependencies.org/u/pos/index.html) is performed - using the `UPOS` field: enable with `data: use_upos: true`. - -- Language-specific part-of-speech tagging is performed using the `XPOS` - field: enable with `data: use_xpos: true`. - -- Morphological feature tagging is performed using the `FEATS` field: enable - with `data: use_feats: true`. +- Lemmatization is performed using the `LEMMA` field and edit scripts. +- [Universal part-of-speech + tagging](https://universaldependencies.org/u/pos/index.html) is performed + using the `UPOS` field. +- Language-specific part-of-speech tagging is performed using the `XPOS` field. +- Morphological feature tagging is performed using the `FEATS` field. +- Dependency parsing is performed using the `HEAD` and `DEPREL` fields, a deep + biaffine parser, and minimum spanning tree decoding. The following caveats apply: -- Note that many newer Universal Dependencies datasets do not have - language-specific part-of-speech-tags. -- The `FEATS` field is treated as a single unit and is not segmented in any - way. -- One can convert from [Universal Dependencies morphological - features](https://universaldependencies.org/u/feat/index.html) to [UniMorph - features](https://unimorph.github.io/schema/) using - [`scripts/convert_to_um.py`](scripts/convert_to_um.py). -- UDTube does not perform dependency parsing at present, so the `HEAD`, - `DEPREL`, and `DEPS` fields are ignored and should be specified as `_`. +- By default, lemmatization uses reverse-edit scripts.
This is appropriate for + predominantly suffixal languages, which are thought to represent the majority + of the world's languages. If working with a predominantly prefixal language, + disable this with `data: reverse_edits: false`. +- Note that many newer Universal Dependencies datasets do not have + language-specific part-of-speech tags, so this task should be disabled + (`data: use_xpos: false`). +- The `FEATS` field is treated as a single unit and is not segmented in any way. +- One can convert from [Universal Dependencies morphological + features](https://universaldependencies.org/u/feat/index.html) to [UniMorph + features](https://unimorph.github.io/schema/) using + [`scripts/convert_to_um.py`](scripts/convert_to_um.py). ## Usage @@ -189,7 +187,7 @@ information](https://github.com/CUNY-CL/yoyodyne/blob/master/README.md#logging). #### Other options -By default, UDTube attempts to model all four tasks; one can disable the +By default, UDTube attempts to model all five tasks; one can disable the language-specific tagging task using `model: use_xpos: false`, and so on. Dropout probability is specified using `model: dropout: ...`. @@ -198,11 +196,6 @@ The encoder has multiple layers. The input to the classifier consists of just the last few layers mean-pooled together. The number of layers used for mean-pooling is specified using `model: pooling_layers: ...`. -By default, lemmatization uses reverse-edit scripts. This is appropriate for -predominantly suffixal languages, which are thought to represent the majority of -the world's languages. If working with a predominantly prefixal language, -disable this with `model: reverse_edits: false`. - The following YAML snippet shows the default architectural arguments. ... @@ -210,13 +203,12 @@ The following YAML snippet shows the default architectural arguments.
dropout: 0.5 encoder: google-bert/bert-base-multilingual-cased pooling_layers: 1 - reverse_edits: true use_upos: true use_xpos: true use_lemma: true use_feats: true + use_parse: true ... - Batch size is specified using `data: batch_size: ...` and defaults to 32. @@ -268,14 +260,14 @@ written. Here are some additional details: -- In `predict` mode UDTube loads the file to be labeled incrementally (i.e., - one sentence at a time) so this can be used with very large files. -- In `predict` mode, if no path for the predictions is specified, stdout will - be used. If using this in conjunction with \> or \|, add - `--trainer.enable_progress_bar false` on the command line. -- The target task fields are overriden if their heads are active. -- Use [`scripts/pretokenize.py`](scripts/pretokenize.py) to convert raw text - files to CoNLL-U input files. +- In `predict` mode UDTube loads the file to be labeled incrementally (i.e., one + sentence at a time) so this can be used with very large files. +- In `predict` mode, if no path for the predictions is specified, stdout will be + used. If using this in conjunction with \> or \|, add + `--trainer.enable_progress_bar false` on the command line. +- The target task fields are overridden if their heads are active. +- Use [`scripts/pretokenize.py`](scripts/pretokenize.py) to convert raw text + files to CoNLL-U input files. This mode is invoked using the `predict` subcommand, like so: @@ -322,3 +314,6 @@ following document, which describes the model: Yakubov, D. 2024. [How do we learn what we cannot say?](https://academicworks.cuny.edu/gc_etds/5622/) Master's thesis, CUNY Graduate Center. + +(See also [`udtube.bib`](udtube.bib) for more work used during the development +of this library.)
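The reverse-edit-script behavior documented in the README hunk above merits a concrete illustration. The sketch below is not UDTube's implementation; it is a minimal, hypothetical demonstration built on Python's `difflib` of why reversing both strings helps for suffixal languages: position-anchored edit scripts computed over reversed strings coincide across stems of different lengths, so the lemmatizer has fewer distinct script "tags" to learn.

```python
import difflib

def edit_script(form: str, lemma: str):
    """Position-anchored opcodes rewriting form into lemma (illustration only)."""
    ops = difflib.SequenceMatcher(None, form, lemma).get_opcodes()
    # Keep only the spans that change, plus their replacement text.
    return tuple(
        (tag, i1, i2, lemma[j1:j2])
        for tag, i1, i2, j1, j2 in ops
        if tag != "equal"
    )

def reverse_edit_script(form: str, lemma: str):
    # Reversing both strings moves suffixal changes to the front,
    # anchoring them at stem-independent positions.
    return edit_script(form[::-1], lemma[::-1])

# Forward scripts differ: the suffix starts at a stem-dependent offset.
print(edit_script("walking", "walk"), edit_script("going", "go"))
# Reversed scripts are identical, so one script covers both word pairs.
print(reverse_edit_script("walking", "walk") == reverse_edit_script("going", "go"))
```

For a predominantly prefixal language the situation flips, which is why the PR exposes the switch as `data: reverse_edits: false`.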
diff --git a/configs/ewt_bert.yaml b/configs/ewt_bert.yaml index 60b637e..0a42533 100644 --- a/configs/ewt_bert.yaml +++ b/configs/ewt_bert.yaml @@ -22,12 +22,6 @@ trainer: model: dropout: 0.4 encoder: google-bert/bert-base-cased - pooling_layers: 4 - reverse_edits: true - use_upos: true - use_xpos: true - use_lemma: true - use_feats: true encoder_optimizer: class_path: torch.optim.Adam init_args: diff --git a/configs/ewt_distilbert.yaml b/configs/ewt_distilbert.yaml index 10a39f0..26e0bd9 100644 --- a/configs/ewt_distilbert.yaml +++ b/configs/ewt_distilbert.yaml @@ -22,12 +22,6 @@ trainer: model: dropout: 0.4 encoder: distilbert/distilbert-base-cased - pooling_layers: 4 - reverse_edits: true - use_upos: true - use_xpos: true - use_lemma: true - use_feats: true encoder_optimizer: class_path: torch.optim.Adam init_args: @@ -52,6 +46,7 @@ data: test: /Users/Shinji/UD_English-EWT/en_ewt-ud-test.conllu predict: /Users/Shinji/UD_English-EWT/en_ewt-ud-test.conllu batch_size: 32 + reverse_edits: true checkpoint: filename: "model-{epoch:03d}-{val_loss:.4f}" monitor: val_loss diff --git a/configs/ewt_roberta.yaml b/configs/ewt_roberta.yaml index 7fbe0eb..44df14d 100644 --- a/configs/ewt_roberta.yaml +++ b/configs/ewt_roberta.yaml @@ -22,12 +22,6 @@ trainer: model: dropout: 0.4 encoder: FacebookAI/roberta-base - pooling_layers: 4 - reverse_edits: true - use_upos: true - use_xpos: true - use_lemma: true - use_feats: true encoder_optimizer: class_path: torch.optim.Adam init_args: diff --git a/configs/syntagrus_mbert.yaml b/configs/syntagrus_mbert.yaml index 9118fd5..37c3673 100644 --- a/configs/syntagrus_mbert.yaml +++ b/configs/syntagrus_mbert.yaml @@ -22,12 +22,7 @@ trainer: model: dropout: 0.4 encoder: google-bert/bert-base-multilingual-cased - pooling_layers: 4 - reverse_edits: true - use_upos: true use_xpos: false - use_lemma: true - use_feats: true encoder_optimizer: class_path: torch.optim.Adam init_args: diff --git a/configs/syntagrus_rubert.yaml 
b/configs/syntagrus_rubert.yaml index ba740b9..91d156d 100644 --- a/configs/syntagrus_rubert.yaml +++ b/configs/syntagrus_rubert.yaml @@ -22,12 +22,7 @@ trainer: model: dropout: 0.4 encoder: DeepPavlov/rubert - pooling_layers: 4 - reverse_edits: true - use_upos: true use_xpos: false - use_lemma: true - use_feats: true encoder_optimizer: class_path: torch.optim.Adam init_args: diff --git a/configs/syntagrus_xlm-roberta.yaml b/configs/syntagrus_xlm-roberta.yaml index e6d2f5e..90293c2 100644 --- a/configs/syntagrus_xlm-roberta.yaml +++ b/configs/syntagrus_xlm-roberta.yaml @@ -22,12 +22,7 @@ trainer: model: dropout: 0.4 encoder: FacebookAI/xlm-roberta-base - pooling_layers: 4 - reverse_edits: true - use_upos: true use_xpos: false - use_lemma: true - use_feats: true encoder_optimizer: class_path: torch.optim.Adam init_args: diff --git a/examples/wandb_sweeps/configs/ewt_grid.yaml b/examples/wandb_sweeps/configs/ewt_grid.yaml index f5131a8..d755ce7 100644 --- a/examples/wandb_sweeps/configs/ewt_grid.yaml +++ b/examples/wandb_sweeps/configs/ewt_grid.yaml @@ -1,4 +1,4 @@ -method: random +method: bayes metric: name: val_loss goal: minimize @@ -10,6 +10,7 @@ parameters: min: 0 max: 0.5 model.encoder: + distribution: categorical values: - FacebookAI/roberta-base - distilbert/distilbert-base-cased @@ -18,7 +19,7 @@ parameters: distribution: q_uniform q: 1 min: 1 - max: 8 + max: 4 model.encoder_optimizer.class_path: value: torch.optim.Adam model.encoder_optimizer.init_args.lr: @@ -31,7 +32,7 @@ parameters: distribution: q_uniform q: 1 min: 1 - max: 20 + max: 40 model.classifier_optimizer.class_path: value: torch.optim.Adam model.classifier_optimizer.init_args.lr: @@ -49,6 +50,7 @@ parameters: model.classifier_scheduler.init_args.patience: value: 5 data.batch_size: + distribution: categorical values: - 8 - 16 diff --git a/examples/wandb_sweeps/configs/gdt_grid.yaml b/examples/wandb_sweeps/configs/gdt_grid.yaml index 4809c7e..6c47e38 100644 --- 
a/examples/wandb_sweeps/configs/gdt_grid.yaml +++ b/examples/wandb_sweeps/configs/gdt_grid.yaml @@ -1,4 +1,4 @@ -method: random +method: bayes metric: name: val_loss goal: minimize @@ -10,6 +10,7 @@ parameters: min: 0 max: 0.5 model.encoder: + distribution: categorical values: - google-bert/bert-base-multilingual-cased - FacebookAI/xlm-roberta-base @@ -17,7 +18,7 @@ parameters: distribution: q_uniform q: 1 min: 1 - max: 8 + max: 4 model.encoder_optimizer.class_path: value: torch.optim.Adam model.encoder_optimizer.init_args.lr: @@ -30,7 +31,7 @@ parameters: distribution: q_uniform q: 1 min: 1 - max: 20 + max: 40 model.classifier_optimizer.class_path: value: torch.optim.Adam model.classifier_optimizer.init_args.lr: @@ -48,6 +49,7 @@ parameters: model.classifier_scheduler.init_args.patience: value: 5 data.batch_size: + distribution: categorical values: - 8 - 16 diff --git a/examples/wandb_sweeps/configs/syntagrus_grid.yaml b/examples/wandb_sweeps/configs/syntagrus_grid.yaml index 05d27db..5b982c3 100644 --- a/examples/wandb_sweeps/configs/syntagrus_grid.yaml +++ b/examples/wandb_sweeps/configs/syntagrus_grid.yaml @@ -1,4 +1,4 @@ -method: random +method: bayes metric: name: val_loss goal: minimize @@ -10,6 +10,7 @@ parameters: min: 0 max: 0.5 model.encoder: + distribution: categorical values: - google-bert/bert-base-multilingual-cased - FacebookAI/xlm-roberta-base @@ -18,7 +19,7 @@ parameters: distribution: q_uniform q: 1 min: 1 - max: 8 + max: 4 model.encoder_optimizer.class_path: value: torch.optim.Adam model.encoder_optimizer.init_args.lr: @@ -31,7 +32,7 @@ parameters: distribution: q_uniform q: 1 min: 1 - max: 20 + max: 40 model.classifier_optimizer.class_path: value: torch.optim.Adam model.classifier_optimizer.init_args.lr: @@ -49,6 +50,7 @@ parameters: model.classifier_scheduler.init_args.patience: value: 5 data.batch_size: + distribution: categorical values: - 8 - 16 diff --git a/pyproject.toml b/pyproject.toml index e8daf1d..091988d 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "udtube" -version = "0.1.12" +version = "0.2.0" description = "Neural morphological analysis" license = "Apache-2.0" readme = "README.md" diff --git a/tests/metrics_test.py b/tests/metrics_test.py new file mode 100644 index 0000000..7bd9e3c --- /dev/null +++ b/tests/metrics_test.py @@ -0,0 +1,218 @@ +"""Unit tests for dependency parsing metrics.""" + +import unittest + +import torch + +from udtube import metrics, special + + +class UnlabeledAttachmentScoreTest(unittest.TestCase): + + def setUp(self): + self.metric = metrics.UnlabeledAttachmentScore( + ignore_index=special.PAD_IDX + ) + + def test_perfect_accuracy(self): + pred_heads = torch.tensor([[3, 3, 4, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + self.metric.update(pred_heads, target_heads) + self.assertEqual(self.metric.compute().item(), 1.0) + + def test_zero_accuracy(self): + pred_heads = torch.tensor([[3, 4, 5, 6]]) + target_heads = torch.tensor([[4, 5, 6, 3]]) + self.metric.update(pred_heads, target_heads) + self.assertEqual(self.metric.compute().item(), 0.0) + + def test_partial_accuracy(self): + pred_heads = torch.tensor([[3, 3, 6, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + self.metric.update(pred_heads, target_heads) + self.assertAlmostEqual(self.metric.compute().item(), 0.75) + + def test_padding_ignored(self): + pred_heads = torch.tensor([[3, 3, 4, special.PAD_IDX]]) + target_heads = torch.tensor([[3, 3, 4, special.PAD_IDX]]) + self.metric.update(pred_heads, target_heads) + self.assertEqual(self.metric.compute().item(), 1.0) + + def test_padding_does_not_affect_score(self): + pred_heads = torch.tensor([[3, 3, 4, 5]]) + target_heads = torch.tensor([[3, 3, 4, special.PAD_IDX]]) + self.metric.update(pred_heads, target_heads) + self.assertEqual(self.metric.compute().item(), 1.0) + + def
test_multiple_batches(self): + pred_heads1 = torch.tensor([[3, 3, 4, 5]]) + target_heads1 = torch.tensor([[3, 3, 4, 5]]) + self.metric.update(pred_heads1, target_heads1) + pred_heads2 = torch.tensor([[3, 3, 6, 5]]) + target_heads2 = torch.tensor([[3, 3, 4, 5]]) + self.metric.update(pred_heads2, target_heads2) + self.assertAlmostEqual(self.metric.compute().item(), 0.875) + + def test_reset(self): + pred_heads = torch.tensor([[3, 3, 4, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + self.metric.update(pred_heads, target_heads) + self.assertEqual(self.metric.compute().item(), 1.0) + self.metric.reset() + self.assertEqual(self.metric.compute().item(), 0.0) + + def test_empty_metric(self): + self.assertEqual(self.metric.compute().item(), 0.0) + + def test_batch_dimension(self): + pred_heads = torch.tensor( + [ + [3, 3, 4, 5], + [3, 3, 6, 5], + [3, 3, 4, 5], + ] + ) + target_heads = torch.tensor( + [ + [3, 3, 4, 5], + [3, 3, 4, 5], + [3, 3, 4, 5], + ] + ) + self.metric.update(pred_heads, target_heads) + self.assertAlmostEqual(self.metric.compute().item(), 11.0 / 12.0) + + +class LabeledAttachmentScoreTest(unittest.TestCase): + + def setUp(self): + self.metric = metrics.LabeledAttachmentScore( + ignore_index=special.PAD_IDX + ) + + def test_perfect_accuracy(self): + pred_heads = torch.tensor([[3, 3, 4, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + pred_labels = torch.tensor([[8, 13, 6, 10]]) + target_labels = torch.tensor([[8, 13, 6, 10]]) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertEqual(self.metric.compute().item(), 1.0) + + def test_zero_accuracy(self): + pred_heads = torch.tensor([[3, 4, 5, 6]]) + target_heads = torch.tensor([[3, 6, 4, 3]]) + pred_labels = torch.tensor([[4, 5, 6, 7]]) + target_labels = torch.tensor([[8, 9, 10, 11]]) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertEqual(self.metric.compute().item(), 0.0) + + def 
test_correct_heads_wrong_labels(self): + pred_heads = torch.tensor([[3, 3, 4, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + pred_labels = torch.tensor([[4, 5, 6, 7]]) + target_labels = torch.tensor([[8, 9, 10, 11]]) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertEqual(self.metric.compute().item(), 0.0) + + def test_wrong_heads_correct_labels(self): + pred_heads = torch.tensor([[3, 4, 5, 6]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + pred_labels = torch.tensor([[8, 13, 6, 10]]) + target_labels = torch.tensor([[8, 13, 6, 10]]) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertEqual(self.metric.compute().item(), 0.25) + + def test_partial_accuracy(self): + pred_heads = torch.tensor([[3, 3, 6, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + pred_labels = target_labels = torch.tensor([[8, 13, 6, 10]]) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertAlmostEqual(self.metric.compute().item(), 0.75) + + def test_padding_ignored(self): + pred_heads = torch.tensor([[3, 3, 4, special.PAD_IDX]]) + target_heads = torch.tensor([[3, 3, 4, special.PAD_IDX]]) + pred_labels = torch.tensor([[8, 13, 6, special.PAD_IDX]]) + target_labels = torch.tensor([[8, 13, 6, special.PAD_IDX]]) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertEqual(self.metric.compute().item(), 1.0) + + def test_multiple_batches(self): + pred_heads1 = target_heads1 = torch.tensor([[3, 3, 4, 5]]) + pred_labels1 = torch.tensor([[8, 13, 4, 10]]) + target_labels1 = torch.tensor([[8, 13, 6, 10]]) + self.metric.update( + pred_heads1, target_heads1, pred_labels1, target_labels1 + ) + pred_heads2 = target_heads2 = torch.tensor([[3, 3, 6, 5]]) + pred_labels2 = torch.tensor([[8, 13, 5, 10]]) + target_labels2 = torch.tensor([[8, 13, 6, 10]]) + self.metric.update( + pred_heads2, target_heads2, pred_labels2, target_labels2 + ) + 
self.assertAlmostEqual(self.metric.compute().item(), 0.75) + + def test_requires_labels(self): + pred_heads = torch.tensor([[3, 3, 4, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + with self.assertRaises(AssertionError): + self.metric.update(pred_heads, target_heads) + + def test_reset(self): + pred_heads = torch.tensor([[3, 3, 4, 5]]) + target_heads = torch.tensor([[3, 3, 4, 5]]) + pred_labels = torch.tensor([[8, 13, 6, 10]]) + target_labels = torch.tensor([[8, 13, 6, 10]]) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertEqual(self.metric.compute().item(), 1.0) + self.metric.reset() + self.assertEqual(self.metric.compute().item(), 0.0) + + def test_batch_dimension(self): + pred_heads = target_heads = torch.tensor( + [ + [3, 3, 4, 5], + [3, 3, 4, 5], + [3, 3, 4, 5], + ] + ) + pred_labels = torch.tensor( + [ + [8, 13, 6, 10], + [8, 13, 5, 10], + [8, 13, 6, 10], + ] + ) + target_labels = torch.tensor( + [ + [8, 13, 6, 10], + [8, 13, 6, 10], + [8, 13, 6, 10], + ] + ) + self.metric.update( + pred_heads, target_heads, pred_labels, target_labels + ) + self.assertAlmostEqual(self.metric.compute().item(), 11.0 / 12.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/udtube.bib b/udtube.bib new file mode 100644 index 0000000..476fb5e --- /dev/null +++ b/udtube.bib @@ -0,0 +1,32 @@ +@mastersthesis{Yakubov:24, + author = {Yakubov, Daniel}, + year = {2024}, + title = {How do we learn what we cannot say?}, + school = {CUNY Graduate Center}} + +@inproceedings{Chrupala:14, + author = {Chrupała, Grzegorz}, + year = {2014}, + title = {Normalizing tweets with edit scripts and recurrent neural embeddings}, + booktitle = {Proceedings of the 52nd Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, + pages = {680-686}} + +@inproceedings{Dozat:Manning:17, + author = {Dozat, Timothy and Manning, Chris D.}, + year = {2017}, + title = {Deep biaffine attention for dependency parsing}, + 
booktitle = {ICLR}} + +@inproceedings{Straka:18, + author = {Straka, Milan}, + year = {2018}, + title = {{UDPipe 2.0} prototype at {CoNLL 2018 UD} shared task}, + booktitle = {Proceedings of the CoNLL 2018 Shared Task: Multilingual Parsing from Raw Text to Universal Dependencies}, + pages = {197-207}} + +@inproceedings{Straka:EtAl:19, + author = {Straka, Milan and Straková, Jana and Hajič, Jan}, + year = {2019}, + title = {{UDPipe 2.0} at {SIGMORPHON 2019}: Contextualized embeddings, regularization with morphological categories, corpora merging}, + booktitle = {Proceedings of the 16th Workshop on Computational Research in Phonetics, Phonology, and Morphology}, + pages = {95-103}} diff --git a/udtube/__init__.py b/udtube/__init__.py index f61bb5c..7f51855 100644 --- a/udtube/__init__.py +++ b/udtube/__init__.py @@ -1,7 +1,4 @@ -"""UDTube: a neural morphological analyzer. - -This module just silences some uninformative warnings. -""" +"""UDTube: a neural morphological analyzer.""" import os import warnings @@ -9,13 +6,12 @@ # Silences tokenizers warning about forking. os.environ["TOKENIZERS_PARALLELISM"] = "false" - # Silences some stupid warnings.
-warnings.filterwarnings("ignore", ".*both args and command line arguments.*") -warnings.filterwarnings("ignore", ".*need to be provided during `Trainer`.*") -warnings.filterwarnings("ignore", ".*is a wandb run already in progress.*") -warnings.filterwarnings("ignore", ".*`tensorboardX` has been removed.*") -warnings.filterwarnings("ignore", ".*does not have many workers.*") -warnings.filterwarnings("ignore", ".*Couldn't infer the batch indices.*") -warnings.filterwarnings("ignore", ".*in eval mode at the start of training.*") -warnings.filterwarnings("ignore", ".*smaller than the logging interval.*") +warnings.filterwarnings("ignore", r"(?s).*both args and command line.*") +warnings.filterwarnings("ignore", r"(?s).*need to be provided during.*") +warnings.filterwarnings("ignore", r"(?s).*is a wandb run already in.*") +warnings.filterwarnings("ignore", r"(?s).*`tensorboardX` has been removed.*") +warnings.filterwarnings("ignore", r"(?s).*does not have many workers.*") +warnings.filterwarnings("ignore", r"(?s).*Couldn't infer the batch indices.*") +warnings.filterwarnings("ignore", r"(?s).*in eval mode at the start of.*") +warnings.filterwarnings("ignore", r"(?s).*smaller than the logging interval.*") diff --git a/udtube/callbacks.py b/udtube/callbacks.py index 42d42f9..a902de7 100644 --- a/udtube/callbacks.py +++ b/udtube/callbacks.py @@ -2,7 +2,7 @@ import logging import sys -from typing import Iterator, Optional, Sequence, TextIO +from typing import Iterator, Optional, Sequence, TextIO, Tuple import lightning from lightning.pytorch import callbacks, trainer @@ -36,7 +36,9 @@ def __init__( # Required API. 
def on_predict_start( - self, trainer: trainer.Trainer, pl_module: lightning.LightningModule + self, + trainer: trainer.Trainer, + pl_module: lightning.LightningModule, ) -> None: # Placing this here prevents the creation of an empty file in the case # where a prediction callback was specified but UDTube is not running @@ -48,12 +50,13 @@ def write_on_batch_end( self, trainer: trainer.Trainer, model: models.UDTube, - logits: data.Logits, + logits_mask: Tuple[data.Logits, torch.Tensor], batch_indices: Optional[Sequence[int]], batch: data.Batch, batch_idx: int, dataloader_idx: int, ) -> None: + logits, mask = logits_mask mapper = data.Mapper(trainer.datamodule.index) # Batch-level argmax on the classification heads. upos_hat = ( @@ -68,6 +71,9 @@ def write_on_batch_end( feats_hat = ( torch.argmax(logits.feats, dim=1) if logits.use_feats else None ) + head_hat, deprel_hat = model.classifier.parse_head.decode( + logits.head, logits.deprel, mask + ) for i, tokenlist in enumerate(batch.tokenlists): # Sentence-level decoding of the classification indices, followed # by rewriting the fields in the tokenlist. @@ -85,19 +91,25 @@ def write_on_batch_end( if feats_hat is not None: feats_it = mapper.decode_feats(feats_hat[i, :]) self._fill_in_tags(tokenlist, "feats", feats_it) + if head_hat is not None and deprel_hat is not None: + head_it = mapper.decode_head(head_hat[i, :]) + deprel_it = mapper.decode_deprel(deprel_hat[i, :]) + self._fill_in_parse(tokenlist, head_it, deprel_it) print(tokenlist, file=self.sink) self.sink.flush() @staticmethod def _fill_in_tags( - tokenlist: data.conllu.TokenList, attr: str, tags: Iterator[str] + tokenlist: data.conllu.TokenList, + attr: str, + tags: Iterator[str], ) -> None: """Helper method for copying tags into tokenlist. Args: - tokenlist (data.conllu.TokenList): tokenlist to insert into. - attr (str): attribute on tokens where the tags should be inserted. - tags (Iterator[str]): tags to insert. + tokenlist: tokenlist to insert into. 
+ attr: attribute on tokens where the tags should be inserted. + tags: iterator over tags to insert. """ # Note that when MWEs are present, the iterators with predicted tags # from the classifier heads are shorter than the tokenlists, so we @@ -110,13 +122,45 @@ def _fill_in_tags( except StopIteration: # Prevents the error from being caught by Lightning. logging.error( - f"Length mismatch at tag {attr!r} (sent_id: " + f"Length mismatch at token {token!r} (sent_id: " + f"{tokenlist.metadata.get('sent_id')})" + ) + continue + + @staticmethod + def _fill_in_parse( + tokenlist: data.conllu.TokenList, + head: Iterator[str], + deprel: Iterator[str], + ) -> None: + """Helper method for copying parser data into tokenlist. + + Args: + tokenlist: tokenlist to insert into. + head: iterator over heads. + deprel: iterator over dependency relations. + """ + # Note that when MWEs are present, the iterators with predicted tags + # from the classifier heads are shorter than the tokenlists, so we + # `continue` without advancing said iterators. + for token in tokenlist: + if token.is_mwe: + continue + try: + token.head = next(head) + token.deprel = next(deprel) + except StopIteration: + # Prevents the error from being caught by Lightning. 
+ logging.error( + f"Length mismatch at token {token!r} (sent_id: " f"{tokenlist.metadata.get('sent_id')})" ) continue def on_predict_end( - self, trainer: trainer.Trainer, pl_module: lightning.LightningModule + self, + trainer: trainer.Trainer, + pl_module: lightning.LightningModule, ) -> None: if self.sink is not sys.stdout: self.sink.close() diff --git a/udtube/cli.py b/udtube/cli.py index 201fd25..8b9f1b6 100644 --- a/udtube/cli.py +++ b/udtube/cli.py @@ -31,11 +31,11 @@ def add_arguments_to_parser( parser.link_arguments( "data.model_dir", "trainer.logger.init_args.save_dir" ) - parser.link_arguments("model.reverse_edits", "data.reverse_edits") parser.link_arguments("model.use_upos", "data.use_upos") parser.link_arguments("model.use_xpos", "data.use_xpos") parser.link_arguments("model.use_lemma", "data.use_lemma") parser.link_arguments("model.use_feats", "data.use_feats") + parser.link_arguments("model.use_parse", "data.use_parse") parser.link_arguments( "data.upos_tagset_size", "model.upos_out_size", apply_on="instantiate", ) @@ -56,6 +56,11 @@ "model.feats_out_size", apply_on="instantiate", ) + parser.link_arguments( + "data.deprel_tagset_size", + "model.deprel_out_size", + apply_on="instantiate", + ) def main() -> None: diff --git a/udtube/data/batches.py b/udtube/data/batches.py index 9b99a0c..4dc2dbf 100644 --- a/udtube/data/batches.py +++ b/udtube/data/batches.py @@ -12,15 +12,16 @@ class Batch(nn.Module): """CoNLL-U data batch. - This can handle padded label tensors if present. - Args: tokenlists: list of TokenLists. tokens: batch encoding from the transformer. - pos: optional padded tensor of universal POS labels. - xpos: optional padded tensor of language-specific POS labels. - lemma: optional padded tensor of lemma labels. - feats: optional padded tensor of morphological feature labels. + pos: optional padded tensor of universal POS tags. + xpos: optional padded tensor of language-specific POS tags. + lemma: optional padded tensor of lemma tags.
+ feats: optional padded tensor of morphological feature tags. + head: optional padded tensor of dependency parser head indices. + deprel: optional padded tensor of dependency parser dependency + relations. """ tokenlists: List[conllu.TokenList] @@ -31,6 +32,8 @@ class Batch(nn.Module): xpos: Optional[torch.Tensor] lemma: Optional[torch.Tensor] feats: Optional[torch.Tensor] + head: Optional[torch.Tensor] + deprel: Optional[torch.Tensor] def __init__( self, @@ -42,6 +45,8 @@ def __init__( xpos=None, lemma=None, feats=None, + head=None, + deprel=None, ): super().__init__() self.tokenlists = tokenlists @@ -52,6 +57,8 @@ def __init__( self.register_buffer("xpos", xpos) self.register_buffer("lemma", lemma) self.register_buffer("feats", feats) + self.register_buffer("head", head) + self.register_buffer("deprel", deprel) @property def use_upos(self) -> bool: @@ -69,5 +76,9 @@ def use_lemma(self) -> bool: def use_feats(self) -> bool: return self.feats is not None + @property + def use_parse(self) -> bool: + return self.head is not None and self.deprel is not None + def __len__(self) -> int: return len(self.tokenlists) diff --git a/udtube/data/collators.py b/udtube/data/collators.py index 5d675e6..6131795 100644 --- a/udtube/data/collators.py +++ b/udtube/data/collators.py @@ -93,6 +93,16 @@ def __call__(self, itemlist: List[datasets.Item]) -> batches.Batch: if itemlist[0].use_feats else None ), + head=( + self.pad_tensors([item.head for item in itemlist]) + if itemlist[0].use_parse + else None + ), + deprel=( + self.pad_tensors([item.deprel for item in itemlist]) + if itemlist[0].use_parse + else None + ), ) @staticmethod @@ -126,7 +136,9 @@ def pad_tensors( return torch.stack( [ nn.functional.pad( - tensor, (0, pad_max - len(tensor)), value=special.PAD_IDX + tensor, + (0, pad_max - len(tensor)), + value=special.PAD_IDX, ) for tensor in tensorlist ] diff --git a/udtube/data/conllu.py b/udtube/data/conllu.py index 43520c4..c902d90 100644 --- a/udtube/data/conllu.py +++ 
b/udtube/data/conllu.py @@ -10,7 +10,15 @@ import dataclasses import re -from typing import Dict, Iterable, Iterator, List, Optional, TextIO, Tuple +from typing import ( + Dict, + Iterable, + Iterator, + List, + Optional, + TextIO, + Tuple, +) from .. import special @@ -126,9 +134,18 @@ class Token: @classmethod def parse_from_string(cls, string: str) -> Token: - id_, form, lemma, upos, xpos, feats, head, deprel, deps, misc = ( - string.split("\t") - ) + ( + id_, + form, + lemma, + upos, + xpos, + feats, + head, + deprel, + deps, + misc, + ) = string.split("\t") return cls( ID.parse_from_string(id_), form, diff --git a/udtube/data/datamodules.py b/udtube/data/datamodules.py index d7143be..339cf20 100644 --- a/udtube/data/datamodules.py +++ b/udtube/data/datamodules.py @@ -36,6 +36,7 @@ class DataModule(lightning.LightningDataModule): use_xpos: Enables the language-specific POS tagging task. use_lemma: Enables the lemmatization task. use_feats: Enables the morphological feature tagging task. + use_parse: Enables the dependency parser task. batch_size: Batch size. """ @@ -48,6 +49,7 @@ class DataModule(lightning.LightningDataModule): use_xpos: bool use_lemma: bool use_feats: bool + use_parse: bool batch_size: int index: indexes.Index tokenizer: transformers.AutoTokenizer @@ -68,6 +70,7 @@ def __init__( use_xpos: bool = defaults.USE_XPOS, use_lemma: bool = defaults.USE_LEMMA, use_feats: bool = defaults.USE_FEATS, + use_parse: bool = defaults.USE_PARSE, # Other. batch_size: int = defaults.BATCH_SIZE, ): @@ -81,6 +84,7 @@ def __init__( self.use_xpos = use_xpos self.use_lemma = use_lemma self.use_feats = use_feats + self.use_parse = use_parse self.batch_size = batch_size # If the training data is specified, it is used to create (or recreate) # the index; if not specified it is read from the model directory. 
@@ -108,6 +112,7 @@ def _make_index(self, model_dir: str) -> indexes.Index: lemma_vocabulary = set() if self.use_lemma else None feats_vocabulary = set() if self.use_feats else None lemma_mapper = mappers.LemmaMapper(self.reverse_edits) + deprel_vocabulary = set() if self.use_parse else None for tokenlist in conllu.parse_from_path(self.train): if self.use_upos: upos_vocabulary.update(token.upos for token in tokenlist) @@ -120,6 +125,8 @@ def _make_index(self, model_dir: str) -> indexes.Index: ) if self.use_feats: feats_vocabulary.update(token.feats for token in tokenlist) + if self.use_parse: + deprel_vocabulary.update(token.deprel for token in tokenlist) index = indexes.Index( reverse_edits=self.reverse_edits, upos=( @@ -138,6 +145,11 @@ def _make_index(self, model_dir: str) -> indexes.Index: if self.use_feats else None ), + deprel=( + indexes.Vocabulary(deprel_vocabulary) + if self.use_parse + else None + ), ) # Writes it to the model directory. os.makedirs(model_dir, exist_ok=True) @@ -195,6 +207,10 @@ def lemma_tagset_size(self) -> int: def feats_tagset_size(self) -> int: return len(self.index.feats) if self.use_feats else 0 + @property + def deprel_tagset_size(self) -> int: + return len(self.index.deprel) if self.use_parse else 0 + # Required API. 
# The training set uses the mappable dataset because of shuffling, and @@ -212,6 +228,7 @@ def train_dataloader(self) -> data.DataLoader: self.use_xpos, self.use_lemma, self.use_feats, + self.use_parse, ), collate_fn=self.collator, batch_size=self.batch_size, @@ -230,6 +247,7 @@ def val_dataloader(self) -> data.DataLoader: self.use_xpos, self.use_lemma, self.use_feats, + self.use_parse, sequential=True, ), collate_fn=self.collator, @@ -260,6 +278,7 @@ def test_dataloader(self) -> data.DataLoader: self.use_xpos, self.use_lemma, self.use_feats, + self.use_parse, ), collate_fn=self.collator, batch_size=self.batch_size, diff --git a/udtube/data/datasets.py b/udtube/data/datasets.py index 440f89c..8b1827a 100644 --- a/udtube/data/datasets.py +++ b/udtube/data/datasets.py @@ -21,9 +21,18 @@ class Item(nn.Module): xpos: Optional[torch.Tensor] lemma: Optional[torch.Tensor] feats: Optional[torch.Tensor] + head: Optional[torch.Tensor] + deprel: Optional[torch.Tensor] def __init__( - self, tokenlist, upos=None, xpos=None, lemma=None, feats=None + self, + tokenlist, + upos=None, + xpos=None, + lemma=None, + feats=None, + head=None, + deprel=None, ): super().__init__() self.tokenlist = tokenlist @@ -31,6 +40,8 @@ def __init__( self.register_buffer("xpos", xpos) self.register_buffer("lemma", lemma) self.register_buffer("feats", feats) + self.register_buffer("head", head) + self.register_buffer("deprel", deprel) def get_tokens(self) -> List[str]: return self.tokenlist.get_tokens() @@ -51,6 +62,10 @@ def use_lemma(self) -> bool: def use_feats(self) -> bool: return self.feats is not None + @property + def use_parse(self) -> bool: + return self.head is not None and self.deprel is not None + @dataclasses.dataclass class AbstractDataset(abc.ABC): @@ -80,6 +95,7 @@ class AbstractTaggedDataset(AbstractDataset): use_xpos: bool use_lemma: bool use_feats: bool + use_parse: bool def tokenlist_to_item(self, tokenlist: conllu.TokenList) -> Item: return Item( @@ -113,6 +129,20 @@ def 
tokenlist_to_item(self, tokenlist: conllu.TokenList) -> Item: if self.use_feats else None ), + head=( + self.mapper.encode_head( + token.head for token in tokenlist if not token.is_mwe + ) + if self.use_parse + else None + ), + deprel=( + self.mapper.encode_deprel( + token.deprel for token in tokenlist if not token.is_mwe + ) + if self.use_parse + else None + ), ) diff --git a/udtube/data/indexes.py b/udtube/data/indexes.py index fd05050..f1126a8 100644 --- a/udtube/data/indexes.py +++ b/udtube/data/indexes.py @@ -75,6 +75,8 @@ class Index: xpos: optional vocabulary for language-specific POS tagging. lemma: optional vocabulary for lemmatization. feats: optional vocabulary for morphological tagging. + deprel: optional vocabulary for dependency parsing dependency + relations. """ reverse_edits: bool = defaults.REVERSE_EDITS @@ -82,6 +84,7 @@ class Index: xpos: Optional[Vocabulary] = None lemma: Optional[Vocabulary] = None feats: Optional[Vocabulary] = None + deprel: Optional[Vocabulary] = None # Serialization. 
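The new `deprel` slot added to `Index` above follows the same vocabulary pattern as the existing tagsets: relation labels observed in the training data are collected into a set, then frozen into an integer index. The following stand-alone sketch illustrates that pattern; this `Vocabulary` stand-in is hypothetical (UDTube's real class is not shown in this diff and also reserves special symbols such as padding before ordinary labels):

```python
# Hypothetical sketch of the vocabulary pattern behind the new DEPREL index.
# UDTube's real Vocabulary class is assumed to work roughly like this.

class Vocabulary:
    def __init__(self, symbols):
        # Index 0 is reserved for padding, mirroring the other tagsets.
        self._symbols = ["<pad>"] + sorted(symbols)
        self._index = {s: i for i, s in enumerate(self._symbols)}

    def __call__(self, symbol: str) -> int:
        # Encoding: symbol -> integer index.
        return self._index[symbol]

    def get_symbol(self, index: int) -> str:
        # Decoding: integer index -> symbol.
        return self._symbols[index]

    def __len__(self) -> int:
        return len(self._symbols)


# Collects DEPREL labels from toy training tokens, as _make_index does.
tokens = [("the", "det"), ("dog", "nsubj"), ("barks", "root")]
deprel_vocabulary = {deprel for _, deprel in tokens}
deprel = Vocabulary(deprel_vocabulary)
```

Encoding and decoding then round-trip through the same object, which is why the mapper methods in this diff can take either the vocabulary itself (for encoding) or its `get_symbol` method (for decoding).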
diff --git a/udtube/data/logits.py b/udtube/data/logits.py index 62e7522..673c245 100644 --- a/udtube/data/logits.py +++ b/udtube/data/logits.py @@ -15,13 +15,25 @@ class Logits(nn.Module): xpos: Optional[torch.Tensor] lemma: Optional[torch.Tensor] feats: Optional[torch.Tensor] - - def __init__(self, upos=None, xpos=None, lemma=None, feats=None): + head: Optional[torch.Tensor] + deprel: Optional[torch.Tensor] + + def __init__( + self, + upos=None, + xpos=None, + lemma=None, + feats=None, + head=None, + deprel=None, + ): super().__init__() self.register_buffer("upos", upos) self.register_buffer("xpos", xpos) self.register_buffer("lemma", lemma) self.register_buffer("feats", feats) + self.register_buffer("head", head) + self.register_buffer("deprel", deprel) @property def use_upos(self) -> bool: @@ -38,3 +50,7 @@ def use_lemma(self) -> bool: @property def use_feats(self) -> bool: return self.feats is not None + + @property + def use_parse(self) -> bool: + return self.head is not None and self.deprel is not None diff --git a/udtube/data/mappers.py b/udtube/data/mappers.py index 2166442..a26353c 100644 --- a/udtube/data/mappers.py +++ b/udtube/data/mappers.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses -from typing import Iterable, Iterator +from typing import Callable, Iterable, Iterator import torch @@ -48,41 +48,42 @@ def __post_init__(self): @staticmethod def _encode( - labels: Iterable[str], - vocabulary: indexes.Vocabulary, + strings: Iterable[str], + functor: Callable[[str], int], ) -> torch.Tensor: """Encodes a tensor. Args: - labels: iterable of labels. - vocabulary: a vocabulary. + strings: iterable of strings. + functor: a callable mapping from strings to integers; usually + this is the vocabulary object. Returns: - Tensor of encoded labels. 
+ Tensor of encoded strings """ - return torch.tensor([vocabulary(label) for label in labels]) + return torch.tensor([functor(string) for string in strings]) - def encode_upos(self, labels: Iterable[str]) -> torch.Tensor: + def encode_upos(self, tags: Iterable[str]) -> torch.Tensor: """Encodes universal POS tags. Args: - labels: iterable of universal POS strings. + tags: iterable of universal POS strings. Returns: - Tensor of encoded labels. + Tensor of encoded tags. """ - return self._encode(labels, self.index.upos) + return self._encode(tags, self.index.upos) - def encode_xpos(self, labels: Iterable[str]) -> torch.Tensor: + def encode_xpos(self, tags: Iterable[str]) -> torch.Tensor: """Encodes language-specific POS tags. Args: - labels: iterable of label-specific POS strings. + tags: iterable of language-specific POS strings. Returns: - Tensor of encoded labels. + Tensor of encoded tags. """ - return self._encode(labels, self.index.xpos) + return self._encode(tags, self.index.xpos) def encode_lemma( self, forms: Iterable[str], lemmas: Iterable[str] @@ -94,7 +95,7 @@ def encode_lemma( lemmas: iterable of lemmas. Returns: - Tensor of encoded labels. + Tensor of encoded lemma tags. """ return self._encode( [ @@ -104,40 +105,61 @@ def encode_lemma( self.index.lemma, ) - def encode_feats(self, labels: Iterable[str]) -> torch.Tensor: + def encode_feats(self, tags: Iterable[str]) -> torch.Tensor: """Encodes morphological feature tags. Args: - labels: iterable of feature tags. + tags: iterable of feature tags. Returns: - Tensor of encoded labels. + Tensor of encoded features. """ - return self._encode(labels, self.index.feats) + return self._encode(tags, self.index.feats) + + def encode_head(self, indices: Iterable[str]) -> torch.Tensor: + """Encodes dependency parsing head indices. + + Args: + indices: iterable of head indices. + + Returns: + Tensor of encoded head indices. + """ + # Cheeky, but it works. 
+        return self._encode(indices, lambda idx: int(idx) + special.OFFSET)
+
+    def encode_deprel(self, deprel: Iterable[str]) -> torch.Tensor:
+        """Encodes dependency parsing dependency relations.
+
+        Args:
+            deprel: iterable of dependency relations.
+
+        Returns:
+            Tensor of encoded dependency relations.
+        """
+        return self._encode(deprel, self.index.deprel)
 
     # Decoding.
 
     @staticmethod
     def _decode(
         indices: torch.Tensor,
-        vocabulary: indexes.Vocabulary,
+        functor: Callable[[int], str],
     ) -> Iterator[str]:
         """Decodes a tensor.
 
         Args:
             indices: tensor of indices.
-            vocabulary: the vocabulary
+            functor: a callable mapping from integers to strings; usually
+                this is the vocabulary object's `get_symbol` method.
 
         Yields:
-            str: decoded symbols.
+            Decoded symbols.
         """
         for idx in indices:
-            if idx == special.PAD_IDX:
-                # To avoid sequence length mismatches,
-                # _ is yielded for anything classified as a pad.
-                yield "_"
-            else:
-                yield vocabulary.get_symbol(idx)
+            # FIXME(kbg): is this the right thing to do? Do I need a special
+            # case for padding?
+            yield functor(idx)
 
     def decode_upos(self, indices: torch.Tensor) -> Iterator[str]:
         """Decodes an upos tensor.
@@ -146,9 +168,9 @@ def decode_upos(self, indices: torch.Tensor) -> Iterator[str]:
             indices: tensor of indices.
 
         Yields:
-            str: decoded upos tags.
+            Decoded upos tags.
         """
-        return self._decode(indices, self.index.upos)
+        return self._decode(indices, self.index.upos.get_symbol)
 
     def decode_xpos(self, indices: torch.Tensor) -> Iterator[str]:
         """Decodes an xpos tensor.
@@ -157,9 +179,9 @@ def decode_xpos(self, indices: torch.Tensor) -> Iterator[str]:
             indices: tensor of indices.
 
         Yields:
-            str: decoded xpos tags.
+            Decoded xpos tags.
         """
-        return self._decode(indices, self.index.xpos)
+        return self._decode(indices, self.index.xpos.get_symbol)
 
     def decode_lemma(
         self, forms: Iterable[str], indices: torch.Tensor
@@ -171,9 +193,11 @@ def decode_lemma(
             indices: tensor of indices.
 
         Yields:
-            str: decoded lemmas.
+            Decoded lemmas.
""" - for form, tag in zip(forms, self._decode(indices, self.index.lemma)): + for form, tag in zip( + forms, self._decode(indices, self.index.lemma.get_symbol) + ): yield self.lemma_mapper.lemmatize(form, tag) def decode_feats(self, indices: torch.Tensor) -> Iterator[str]: @@ -183,6 +207,30 @@ def decode_feats(self, indices: torch.Tensor) -> Iterator[str]: indices: tensor of indices. Yields: - str: decoded morphological features. + Decoded morphological features. + """ + return self._decode(indices, self.index.feats.get_symbol) + + def decode_head(self, indices: torch.Tensor) -> Iterator[str]: + """Encodes dependency parsing head indices. + + Args: + indices: iterable of head indices. + + Returns: + Decoded head indices. + """ + return self._decode( + indices, lambda idx: str(idx.item() - special.OFFSET) + ) + + def decode_deprel(self, indices: torch.Tensor) -> Iterator[str]: + """Decodes dependency parsing dependency relations. + + Args: + indices: tensor of indices. + + Yields: + Decoded dependency relations. """ - return self._decode(indices, self.index.feats) + return self._decode(indices, self.index.deprel.get_symbol) diff --git a/udtube/defaults.py b/udtube/defaults.py index 551129d..2f552c6 100644 --- a/udtube/defaults.py +++ b/udtube/defaults.py @@ -2,17 +2,23 @@ from yoyodyne import optimizers, schedulers +# Scalar constants. +NEG_EPSILON = -1e7 + # Default text encoding. ENCODING = "utf-8" # Architecture arguments. ENCODER = "google-bert/bert-base-multilingual-cased" POOLING_LAYERS = 1 -REVERSE_EDITS = True +ARC_MLP_SIZE = 512 +DEPREL_MLP_SIZE = 128 USE_UPOS = True USE_XPOS = True USE_LEMMA = True USE_FEATS = True +USE_PARSE = True +REVERSE_EDITS = True # Training arguments. BATCH_SIZE = 32 diff --git a/udtube/metrics.py b/udtube/metrics.py new file mode 100644 index 0000000..7646159 --- /dev/null +++ b/udtube/metrics.py @@ -0,0 +1,100 @@ +"""Custom metrics. 
+
+This module implements unlabeled attachment score (UAS) and labeled attachment
+score (LAS) metrics used for dependency parsing evaluation.
+"""
+
+import torch
+import torchmetrics
+
+
+class AttachmentScore(torchmetrics.Metric):
+    """Base class for attachment scores.
+
+    This metric computes the percentage of tokens that have the correct head
+    (and optionally, the correct dependency relation).
+    """
+
+    def __init__(self, labeled: bool, ignore_index: int):
+        """Initializes the attachment score metric.
+
+        Args:
+            labeled: compute LAS rather than UAS?
+            ignore_index: index used for padding.
+        """
+        super().__init__()
+        self.labeled = labeled
+        self.ignore_index = ignore_index
+        self.add_state(
+            "correct", default=torch.tensor(0), dist_reduce_fx="sum"
+        )
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(
+        self,
+        hypo_head: torch.Tensor,
+        gold_head: torch.Tensor,
+        hypo_deprel: torch.Tensor | None = None,
+        gold_deprel: torch.Tensor | None = None,
+    ) -> None:
+        """Accumulates sufficient statistics for a batch.
+
+        Args:
+            hypo_head: hypothesized head indices.
+            gold_head: gold head indices.
+            hypo_deprel: hypothesized dependency relations; required if
+                `labeled=True`.
+            gold_deprel: gold dependency relations; required if
+                `labeled=True`.
+        """
+        assert hypo_head.shape == gold_head.shape, (
+            f"Shape mismatch: hypo_head {hypo_head.shape} "
+            f"!= gold_head {gold_head.shape}"
+        )
+        if self.labeled:
+            assert (
+                hypo_deprel is not None and gold_deprel is not None
+            ), "Labels required for labeled attachment score"
+            assert hypo_deprel.shape == gold_deprel.shape, (
+                f"Shape mismatch: hypo_deprel {hypo_deprel.shape} "
+                f"!= gold_deprel {gold_deprel.shape}"
+            )
+            assert hypo_deprel.shape == hypo_head.shape, (
+                f"Shape mismatch: hypo_deprel {hypo_deprel.shape} "
+                f"!= hypo_head {hypo_head.shape}"
+            )
+        mask = gold_head != self.ignore_index
+        head_correct = (hypo_head == gold_head) & mask
+        if self.labeled:
+            # For LAS, both head and deprel must be correct.
+ deprel_correct = (hypo_deprel == gold_deprel) & mask + self.correct += torch.sum(head_correct & deprel_correct) + else: + # For UAS, only head needs to be correct. + self.correct += head_correct.sum() + self.total += mask.sum() + + def compute(self) -> torch.Tensor: + if self.total == 0: + return torch.tensor(0.0, device=self.device) + return self.correct.float() / self.total.float() + + +class UnlabeledAttachmentScore(AttachmentScore): + """Unlabeled Attachment Score (UAS). + + Computes the percentage of tokens with correct head assignment, + ignoring the dependency relation. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, labeled=False, **kwargs) + + +class LabeledAttachmentScore(AttachmentScore): + """Labeled Attachment Score (LAS). + + Computes the percentage of tokens with both correct head assignment + and correct dependency relation. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, labeled=True, **kwargs) diff --git a/udtube/models.py b/udtube/models.py index 04b1779..8e6985e 100644 --- a/udtube/models.py +++ b/udtube/models.py @@ -1,6 +1,6 @@ """The UDTube model.""" -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import lightning from lightning.pytorch import cli @@ -9,7 +9,7 @@ from torchmetrics import classification import wandb -from . import data, defaults, modules, special +from . import data, defaults, metrics, modules, special class UDTube(lightning.LightningModule): @@ -23,14 +23,13 @@ class UDTube(lightning.LightningModule): dropout: Dropout probability. encoder: Name of the Hugging Face model used to tokenize and encode. pooling_layers: Number of layers to use to compute the embedding. - reverse_edits: By default, lemmatization rules use reverse-edit - scripts, which are appropriate for predominantly suffixal - languages. When working with a predominantly prefixal language, - disable this by setting this to False. 
+        arc_mlp_size: Size of the arc MLP for dependency parsing.
+        deprel_mlp_size: Size of the deprel MLP for dependency parsing.
         use_upos: Enables the universal POS tagging task.
         use_xpos: Enables the language-specific POS tagging task.
         use_lemma: Enables the lemmatization task.
         use_feats: Enables the morphological feature tagging task.
+        use_parse: Enables the dependency parsing task.
     """
 
     encoder: modules.UDTubeEncoder
@@ -41,26 +40,33 @@ class UDTube(lightning.LightningModule):
     xpos_accuracy: Optional[classification.MulticlassAccuracy]
     lemma_accuracy: Optional[classification.MulticlassAccuracy]
     feats_accuracy: Optional[classification.MulticlassAccuracy]
+    unlabeled_score: Optional[metrics.UnlabeledAttachmentScore]
+    labeled_score: Optional[metrics.LabeledAttachmentScore]
 
     def __init__(
         self,
+        *,
         dropout: float = defaults.DROPOUT,
         encoder: str = defaults.ENCODER,
         pooling_layers: int = defaults.POOLING_LAYERS,
-        reverse_edits: bool = defaults.REVERSE_EDITS,
+        arc_mlp_size: int = defaults.ARC_MLP_SIZE,
+        deprel_mlp_size: int = defaults.DEPREL_MLP_SIZE,
         use_upos: bool = defaults.USE_UPOS,
         use_xpos: bool = defaults.USE_XPOS,
         use_lemma: bool = defaults.USE_LEMMA,
         use_feats: bool = defaults.USE_FEATS,
-        *,
+        use_parse: bool = defaults.USE_PARSE,
+        # Optimization.
         encoder_optimizer: cli.OptimizerCallable = defaults.OPTIMIZER,
         encoder_scheduler: cli.LRSchedulerCallable = defaults.SCHEDULER,
         classifier_optimizer: cli.OptimizerCallable = defaults.OPTIMIZER,
         classifier_scheduler: cli.LRSchedulerCallable = defaults.SCHEDULER,
+        # Dummy values.
         upos_out_size: int = 2,  # Dummy value filled in via link.
         xpos_out_size: int = 2,  # Dummy value filled in via link.
         lemma_out_size: int = 2,  # Dummy value filled in via link.
         feats_out_size: int = 2,  # Dummy value filled in via link.
+        deprel_out_size: int = 2,  # Dummy value filled in via link.
     ):
         super().__init__()
         # See what this disables here:
@@ -69,14 +75,19 @@ def __init__(
         self.encoder = modules.UDTubeEncoder(dropout, encoder, pooling_layers)
         self.classifier = modules.UDTubeClassifier(
             self.encoder.hidden_size,
-            use_upos,
-            use_xpos,
-            use_lemma,
-            use_feats,
+            dropout=dropout,
+            arc_mlp_size=arc_mlp_size,
+            deprel_mlp_size=deprel_mlp_size,
             upos_out_size=upos_out_size,
             xpos_out_size=xpos_out_size,
             lemma_out_size=lemma_out_size,
             feats_out_size=feats_out_size,
+            deprel_out_size=deprel_out_size,
+            use_upos=use_upos,
+            use_xpos=use_xpos,
+            use_lemma=use_lemma,
+            use_feats=use_feats,
+            use_parse=use_parse,
         )
         self.loss_func = nn.CrossEntropyLoss(ignore_index=special.PAD_IDX)
         self.upos_accuracy = (
@@ -91,6 +104,16 @@ def __init__(
         self.feats_accuracy = (
             self._make_accuracy(feats_out_size) if use_feats else None
         )
+        self.unlabeled_score = (
+            metrics.UnlabeledAttachmentScore(ignore_index=special.PAD_IDX)
+            if use_parse
+            else None
+        )
+        self.labeled_score = (
+            metrics.LabeledAttachmentScore(ignore_index=special.PAD_IDX)
+            if use_parse
+            else None
+        )
         self.encoder_optimizer = encoder_optimizer
         self.encoder_scheduler = encoder_scheduler
         self.classifier_optimizer = classifier_optimizer
@@ -119,11 +142,16 @@ def use_lemma(self) -> bool:
     def use_feats(self) -> bool:
         return self.classifier.use_feats
 
+    @property
+    def use_parse(self) -> bool:
+        return self.classifier.use_parse
+
     def forward(
         self,
         batch: data.Batch,
-    ) -> data.Logits:
-        return self.classifier(self.encoder(batch))
+    ) -> Tuple[data.Logits, torch.Tensor]:
+        encoding, mask = self.encoder(batch)
+        return self.classifier(encoding, mask), mask
 
     def configure_optimizers(
         self,
@@ -168,7 +196,9 @@ def on_fit_start(self) -> None:
             wandb.define_metric("val_upos_accuracy", summary="max")
             wandb.define_metric("val_xpos_accuracy", summary="max")
 
-    def predict_step(self, batch: data.Batch, batch_idx: int) -> data.Logits:
+    def predict_step(
+
self, batch: data.Batch, batch_idx: int + ) -> Tuple[data.Logits, torch.Tensor]: return self(batch) def training_step( @@ -178,7 +208,7 @@ def training_step( ) -> None: for optimizer in self.optimizers(): optimizer.zero_grad() - logits = self(batch) + logits, _ = self(batch) loss = self._log_loss(logits, batch, "train") self.manual_backward(loss) for optimizer in self.optimizers(): @@ -196,29 +226,29 @@ def on_train_epoch_end(self) -> None: scheduler.step() def on_validation_epoch_start(self) -> None: - self._reset_accuracies() + self._reset_metrics() def validation_step( self, batch: data.Batch, batch_idx: int, ) -> None: - logits = self(batch) + logits, mask = self(batch) self._log_loss(logits, batch, "val") - self._update_accuracies(logits, batch) + self._update_metrics(logits, batch, mask) def on_validation_epoch_end(self) -> None: - self._log_accuracies_epoch_end("val") + self._log_metrics_epoch_end("val") def on_test_step_epoch_start(self) -> None: - self._reset_accuracies() + self._reset_metrics() def test_step(self, batch: data.Batch, batch_idx: int) -> None: - logits = self(batch) - self._update_accuracies(logits, batch) + logits, mask = self(batch) + self._update_metrics(logits, batch, mask) def on_test_epoch_end(self) -> None: - self._log_accuracies_epoch_end("test") + self._log_metrics_epoch_end("test") def _log_loss( self, logits: data.Logits, batch: data.Batch, subset: str @@ -226,16 +256,23 @@ def _log_loss( losses = [] if self.use_upos: losses.append(self.loss_func(logits.upos, batch.upos)) - self.upos_accuracy.update(logits.upos, batch.upos) if self.use_xpos: losses.append(self.loss_func(logits.xpos, batch.xpos)) - self.xpos_accuracy.update(logits.xpos, batch.xpos) if self.use_lemma: losses.append(self.loss_func(logits.lemma, batch.lemma)) - self.lemma_accuracy.update(logits.lemma, batch.lemma) if self.use_feats: losses.append(self.loss_func(logits.feats, batch.feats)) - self.feats_accuracy.update(logits.feats, batch.feats) + if self.use_parse: + 
head_loss, deprel_loss = self.classifier.parse_head.compute_loss( + logits.head, + batch.head, + logits.deprel, + batch.deprel, + ) + # TODO(kbg): maybe something more sophisticated or general is + # called for here; test later. + losses.append(head_loss) + losses.append(deprel_loss) loss = torch.sum(torch.stack(losses)) if not self.trainer.sanity_checking: self.log( @@ -250,7 +287,7 @@ def _log_loss( # We can use the returned loss to step the optimizers. return loss - def _reset_accuracies(self) -> None: + def _reset_metrics(self) -> None: if self.use_upos: self.upos_accuracy.reset() if self.use_xpos: @@ -259,9 +296,12 @@ def _reset_accuracies(self) -> None: self.lemma_accuracy.reset() if self.use_feats: self.feats_accuracy.reset() + if self.use_parse: + self.unlabeled_score.reset() + self.labeled_score.reset() - def _update_accuracies( - self, logits: data.Logits, batch: data.Batch + def _update_metrics( + self, logits: data.Logits, batch: data.Batch, mask: torch.Tensor ) -> None: if self.use_upos: self.upos_accuracy.update(logits.upos, batch.upos) @@ -271,8 +311,14 @@ def _update_accuracies( self.lemma_accuracy.update(logits.lemma, batch.lemma) if self.use_feats: self.feats_accuracy.update(logits.feats, batch.feats) + if self.use_parse: + head, deprel = self.classifier.parse_head.decode( + logits.head, logits.deprel, mask + ) + self.unlabeled_score.update(head, batch.head) + self.labeled_score.update(head, batch.head, deprel, batch.deprel) - def _log_accuracies_epoch_end(self, subset: str) -> None: + def _log_metrics_epoch_end(self, subset: str) -> None: if self.use_upos: self.log( f"{subset}_upos_accuracy", @@ -305,3 +351,18 @@ def _log_accuracies_epoch_end(self, subset: str) -> None: logger=True, prog_bar=True, ) + if self.use_parse: + self.log( + f"{subset}_unlabeled_score", + self.unlabeled_score.compute(), + on_epoch=True, + logger=True, + prog_bar=True, + ) + self.log( + f"{subset}_labeled_score", + self.labeled_score.compute(), + on_epoch=True, + 
logger=True, + prog_bar=True, + ) diff --git a/udtube/modules.py b/udtube/modules.py index 6868879..3c2729f 100644 --- a/udtube/modules.py +++ b/udtube/modules.py @@ -5,7 +5,7 @@ or tags) of a sentence in the batch. """ -from typing import List, Optional +from typing import List, Optional, Tuple import lightning import tokenizers @@ -13,7 +13,7 @@ from torch import nn import transformers -from . import data, defaults, encoders +from . import data, defaults, encoders, parser class Error(Exception): @@ -55,7 +55,7 @@ def hidden_size(self) -> int: def forward( self, batch: data.Batch, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Computes the contextual word-level encoding. This discards over-long sequences (if necessary), computes the subword @@ -66,7 +66,7 @@ def forward( batch: a data batch. Returns: - A contextual word-level encoding. + A contextual word-level encoding and a matching mask. """ # We move these manually. output = self.encoder( @@ -85,8 +85,7 @@ def forward( # Applies dropout. x = self.dropout_layer(x) # Maps from subword embeddings to word-level embeddings. - x = self._group_embeddings(x, batch.encodings) - return x + return self._group_embeddings(x, batch.encodings) def _group_embeddings( self, @@ -104,15 +103,16 @@ def _group_embeddings( encodings: the tokenizer encodings. Returns: - The re-pooled embeddings tensor. + The re-pooled word embeddings tensor and a matching word-level + boolean mask. """ new_sentence_embeddings = [] + lengths = [] for sentence_encodings, sentence_embeddings in zip( encodings, embeddings ): - # This looks like an overly elaborate loop that could be a list - # comprehension, but this is much faster. indices = [] + num_words = 0 i = 0 while i < len(sentence_encodings.word_ids): word_id = sentence_encodings.word_ids[i] @@ -121,6 +121,7 @@ def _group_embeddings( break pair = sentence_encodings.word_to_tokens(word_id) indices.append(pair) + num_words += 1 # Fast-forwards to the start of the next word. 
                 i = pair[-1]
         # For each span of subwords, combine via mean and then stack them.
@@ -132,6 +133,7 @@ def _group_embeddings(
                     ]
                 )
             )
+            lengths.append(num_words)
         # Pads and stacks across sentences; the leading dimension is ragged
         # but `pad` cowardly refuses to pad non-trailing dimensions, so we
         # abuse transposition.
@@ -139,7 +141,7 @@ def _group_embeddings(
             len(sentence_embedding)
             for sentence_embedding in new_sentence_embeddings
         )
-        return torch.stack(
+        encoding = torch.stack(
            [
                 nn.functional.pad(
                     sentence_embedding.mT,
@@ -149,6 +151,16 @@ def _group_embeddings(
                 for sentence_embedding in new_sentence_embeddings
             ]
         ).transpose(1, 2)
+        # Makes word-level mask.
+        mask = torch.zeros(
+            len(encodings),
+            pad_max,
+            device=self.device,
+            dtype=bool,
+        )
+        for i, length in enumerate(lengths):
+            mask[i, :length] = True
+        return encoding, mask
 
 
 class UDTubeClassifier(lightning.LightningModule):
@@ -156,39 +168,44 @@ class UDTubeClassifier(lightning.LightningModule):
 
     Args:
         hidden_size: size of the encoder hidden layer.
-        use_upos: enables the universal POS tagging task.
-        use_xpos: enables the language-specific POS tagging task.
-        use_lemma: enables the lemmatization task.
-        use_feats: enables the morphological feature tagging task.
         upos_out_size: number of UPOS classes; usually set automatically.
         xpos_out_size: number of XPOS classes; usually set automatically.
         lemma_out_size: number of LEMMA classes; usually set automatically.
         feats_out_size: number of FEATS classes; usually set automatically.
+        use_upos: enables the universal POS tagging task.
+        use_xpos: enables the language-specific POS tagging task.
+        use_lemma: enables the lemmatization task.
+        use_feats: enables the morphological feature tagging task.
+        use_parse: enables the dependency parsing task.
""" - upos_head: Optional[nn.Sequential] - xpos_head: Optional[nn.Sequential] - lemma_head: Optional[nn.Sequential] - feats_head: Optional[nn.Sequential] + upos_head: Optional[nn.Linear] + xpos_head: Optional[nn.Linear] + lemma_head: Optional[nn.Linear] + feats_head: Optional[nn.Linear] + parse_head: Optional[parser.BiaffineParser] def __init__( self, hidden_size: int, - use_upos: bool = defaults.USE_UPOS, - use_xpos: bool = defaults.USE_XPOS, - use_lemma: bool = defaults.USE_LEMMA, - use_feats: bool = defaults.USE_FEATS, *, + dropout: float = defaults.DROPOUT, + arc_mlp_size: int = defaults.ARC_MLP_SIZE, + deprel_mlp_size: int = defaults.DEPREL_MLP_SIZE, # `2` is a dummy value here; it will be set by the dataset object. upos_out_size: int = 2, xpos_out_size: int = 2, lemma_out_size: int = 2, feats_out_size: int = 2, - # Optimization and LR scheduling. - **kwargs, + deprel_out_size: int = 2, + use_upos: bool = defaults.USE_UPOS, + use_xpos: bool = defaults.USE_XPOS, + use_lemma: bool = defaults.USE_LEMMA, + use_feats: bool = defaults.USE_FEATS, + use_parse: bool = defaults.USE_PARSE, ): super().__init__() - if not any([use_upos, use_xpos, use_lemma, use_feats]): + if not any([use_upos, use_xpos, use_lemma, use_feats, use_parse]): raise Error("No classification heads enabled") self.upos_head = ( nn.Linear(hidden_size, upos_out_size) if use_upos else None @@ -202,6 +219,17 @@ def __init__( self.feats_head = ( nn.Linear(hidden_size, feats_out_size) if use_feats else None ) + self.parse_head = ( + parser.BiaffineParser( + hidden_size, + arc_mlp_size, + deprel_mlp_size, + deprel_out_size, + dropout, + ) + if use_parse + else None + ) # Properties. @@ -221,9 +249,15 @@ def use_lemma(self) -> bool: def use_feats(self) -> bool: return self.feats_head is not None + @property + def use_parse(self) -> bool: + return self.parse_head is not None + # Forward pass. 
- def forward(self, encodings: torch.Tensor) -> data.Logits: + def forward( + self, encodings: torch.Tensor, mask: torch.Tensor + ) -> data.Logits: """Computes logits for each of the classification heads. This takes the contextual word encodings and then computes the logits @@ -232,30 +266,27 @@ def forward(self, encodings: torch.Tensor) -> data.Logits: transpose to produce this shape. Args: - encodings: the contextual word + encodings: word-level encoding tensor. + mask: word-level mask. Returns: - A contextual word-level encoding. + Logit tensors for all the active tasks. """ + if self.use_upos: + upos = self.upos_head(encodings).transpose(1, 2) + if self.use_xpos: + xpos = self.xpos_head(encodings).transpose(1, 2) + if self.use_lemma: + lemma = self.lemma_head(encodings).transpose(1, 2) + if self.use_feats: + feats = self.feats_head(encodings).transpose(1, 2) + if self.use_parse: + head, deprel = self.parse_head(encodings, mask) return data.Logits( - upos=( - self.upos_head(encodings).transpose(1, 2) - if self.use_upos - else None - ), - xpos=( - self.xpos_head(encodings).transpose(1, 2) - if self.use_xpos - else None - ), - lemma=( - self.lemma_head(encodings).transpose(1, 2) - if self.use_lemma - else None - ), - feats=( - self.feats_head(encodings).transpose(1, 2) - if self.use_feats - else None - ), + upos=upos if self.use_upos else None, + xpos=xpos if self.use_xpos else None, + lemma=lemma if self.use_lemma else None, + feats=feats if self.use_feats else None, + head=head if self.use_parse else None, + deprel=deprel if self.use_parse else None, ) diff --git a/udtube/parser.py b/udtube/parser.py new file mode 100644 index 0000000..d003a02 --- /dev/null +++ b/udtube/parser.py @@ -0,0 +1,305 @@ +"""Biaffine attention dependency parser. + +Based on: + + Dozat, T., and Manning, C. D. 2017. Deep biaffine attention for dependency + parsing. In ICLR. +""" + +from typing import Tuple + +import torch +from torch import nn + +from . 
import defaults, special
+
+
+class BiaffineAttention(nn.Module):
+    r"""Biaffine attention mechanism for scoring head-dependent pairs.
+
+    This implements the biaffine transformation:
+
+        score(i, j) = h_j^T U h_i + (h_j \oplus h_i)^T w + b
+
+    where h_i is the dependent representation and h_j is the head
+    representation.
+
+    Args:
+        head_size: Size of head representation.
+        dep_size: Size of dependent representation.
+        out_size: Output dimension; use 1 for head scores and the number of
+            unique dependency relations for deprel scores.
+    """
+
+    head_size: int
+    dep_size: int
+    out_size: int
+    weight: nn.Parameter
+
+    def __init__(
+        self,
+        head_size: int,
+        dep_size: int,
+        out_size: int = 1,
+    ):
+        super().__init__()
+        self.head_size = head_size
+        self.dep_size = dep_size
+        self.out_size = out_size
+        self.weight = nn.Parameter(
+            torch.zeros(self.out_size, self.head_size + 1, self.dep_size + 1)
+        )
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(
+        self,
+        head: torch.Tensor,
+        dep: torch.Tensor,
+    ) -> torch.Tensor:
+        """Computes biaffine attention scores.
+
+        Args:
+            head: Head representations.
+            dep: Dependent representations.
+
+        Returns:
+            For each dependent position, scores over all candidate head
+            positions, of shape B x L x L x out_size.
+        """
+        assert head.shape[0] == dep.shape[0], "Batch size mismatch"
+        assert head.shape[1] == dep.shape[1], "Sequence length mismatch"
+        assert (
+            head.shape[2] == self.head_size
+        ), f"Head size mismatch: {head.shape[2]} != {self.head_size}"
+        assert (
+            dep.shape[2] == self.dep_size
+        ), f"Dep size mismatch: {dep.shape[2]} != {self.dep_size}"
+        # FIXME =2?
+        head = torch.cat((head, torch.ones_like(head[..., :1])), dim=-1)
+        # FIXME =2?
+        dep = torch.cat((dep, torch.ones_like(dep[..., :1])), dim=-1)
+        dep_weight = torch.einsum("bld,odh->blho", dep, self.weight)
+        return torch.einsum("bsh,bdho->bdso", head, dep_weight)
+
+
+class BiaffineParser(nn.Module):
+    """Biaffine parser for dependency arc and deprel prediction.
+
+    This takes the encoder outputs and predicts:
+
+    * Head indices for each token.
+    * Dependency relations for each arc.
+
+    Following Dozat & Manning, we apply separate MLPs to reduce dimensionality
+    before the biaffine classifiers.
+
+    Head data is interpreted as indices, but these indices can collide with
+    the 0 used for padding. We therefore shift and unshift the data to avoid
+    this collision.
+
+    Args:
+        hidden_size: Size of the encoder hidden layer.
+        arc_mlp_size: Hidden layer size for arc MLP.
+        deprel_mlp_size: Hidden layer size for deprel MLP.
+        num_deprel: Number of dependency relation classes; usually set
+            automatically.
+        dropout: Dropout probability for MLP layers.
+    """
+
+    arc_head_mlp: nn.Module
+    arc_dep_mlp: nn.Module
+    deprel_head_mlp: nn.Module
+    deprel_dep_mlp: nn.Module
+    arc_attention: BiaffineAttention
+    deprel_attention: BiaffineAttention
+    loss_func: nn.CrossEntropyLoss
+
+    def __init__(
+        self,
+        hidden_size,
+        arc_mlp_size: int = defaults.ARC_MLP_SIZE,
+        deprel_mlp_size: int = defaults.DEPREL_MLP_SIZE,
+        num_deprel: int = 2,  # Dummy value filled in via link.
+        dropout: float = defaults.DROPOUT,
+    ):
+        super().__init__()
+        self.arc_head_mlp = self._make_mlp(hidden_size, arc_mlp_size, dropout)
+        self.arc_dep_mlp = self._make_mlp(hidden_size, arc_mlp_size, dropout)
+        self.deprel_head_mlp = self._make_mlp(
+            hidden_size, deprel_mlp_size, dropout
+        )
+        self.deprel_dep_mlp = self._make_mlp(
+            hidden_size, deprel_mlp_size, dropout
+        )
+        self.arc_attention = BiaffineAttention(
+            arc_mlp_size,
+            arc_mlp_size,
+            1,
+        )
+        self.deprel_attention = BiaffineAttention(
+            deprel_mlp_size,
+            deprel_mlp_size,
+            num_deprel,
+        )
+        self.loss_func = nn.CrossEntropyLoss(ignore_index=special.PAD_IDX)
+
+    @staticmethod
+    def _make_mlp(
+        input_size: int, hidden_size: int, dropout: float
+    ) -> nn.Module:
+        """Builds a single-layer MLP with ReLU activation and dropout.
+
+        Args:
+            input_size: Input size.
+            hidden_size: Hidden/output size.
+            dropout: Dropout probability.
+
+        Returns:
+            A sequential MLP module.
+        """
+        return nn.Sequential(
+            nn.Linear(input_size, hidden_size),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )
+
+    @staticmethod
+    def _shift_head(head: torch.Tensor) -> torch.Tensor:
+        """Converts indices to internal representation."""
+        return torch.where(
+            head == special.PAD_IDX,
+            head,
+            head + special.OFFSET,
+        )
+
+    @staticmethod
+    def _unshift_head(head: torch.Tensor) -> torch.Tensor:
+        """Converts internal representation to indices."""
+        return torch.where(
+            head == special.PAD_IDX,
+            head,
+            head - special.OFFSET,
+        )
+
+    def forward(
+        self,
+        encodings: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass for dependency parsing.
+
+        Args:
+            encodings: Encoder output of shape N x L x H.
+            mask: Attention mask of shape N x L.
+
+        Returns:
+            A (arc logits, deprel logits) tuple.
+        """
+        batch_size = encodings.size(0)
+        length = encodings.size(1)
+        arc_head = self.arc_head_mlp(encodings)
+        arc_dep = self.arc_dep_mlp(encodings)
+        deprel_head = self.deprel_head_mlp(encodings)
+        deprel_dep = self.deprel_dep_mlp(encodings)
+        # FIXME indices.
+        arc_logits = self.arc_attention(arc_head, arc_dep).squeeze(-1)
+        deprel_logits = self.deprel_attention(deprel_head, deprel_dep)
+        arc_mask = mask.unsqueeze(1)
+        arc_logits.masked_fill_(~arc_mask, defaults.NEG_EPSILON)
+        arc_mask = arc_mask.unsqueeze(-1)
+        deprel_logits.masked_fill_(~arc_mask, defaults.NEG_EPSILON)
+        assert arc_logits.shape == (
+            batch_size,
+            length,
+            length,
+        ), f"Arc logits shape mismatch: {arc_logits.shape}"
+        assert deprel_logits.shape == (
+            batch_size,
+            length,
+            length,
+            self.deprel_attention.out_size,
+        ), f"Deprel logits shape mismatch: {deprel_logits.shape}"
+        return arc_logits, deprel_logits
+
+    def compute_loss(
+        self,
+        head_logits: torch.Tensor,
+        gold_head: torch.Tensor,
+        deprel_logits: torch.Tensor,
+        gold_deprel: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Computes head and deprel cross-entropy losses.
+
+        Following Dozat & Manning, deprel prediction loss is conditioned on
+        the gold heads.
+
+        Standard practice is to weight and combine these; we return them
+        separately to allow for more possibilities downstream.
+
+        Args:
+            head_logits: Head scores.
+            gold_head: Gold head indices.
+            deprel_logits: Dependency relation scores.
+            gold_deprel: Gold dependency relations.
+
+        Returns:
+            The two losses.
+        """
+        gold_head = self._unshift_head(gold_head)
+        head_loss = self.loss_func(
+            head_logits.reshape(-1, head_logits.size(-1)),
+            gold_head.reshape(-1),
+        )
+        length = deprel_logits.size(1)
+        num_deprel = deprel_logits.size(3)
+        gold_head_expanded = (
+            gold_head.unsqueeze(-1)
+            .unsqueeze(-1)
+            .expand(-1, length, 1, num_deprel)
+        )
+        # Selects the deprel logits at each token's gold head.
+        selected_deprel_logits = torch.gather(
+            deprel_logits,
+            dim=2,
+            index=gold_head_expanded,
+        ).squeeze(2)
+        # TODO: consider having the caller pass in the loss function object.
+        deprel_loss = self.loss_func(
+            selected_deprel_logits.reshape(-1, num_deprel),
+            gold_deprel.reshape(-1),
+        )
+        return head_loss, deprel_loss
+
+    def decode(
+        self,
+        head_logits: torch.Tensor,
+        deprel_logits: torch.Tensor,
+        mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Decodes head and deprel predictions from logits.
+
+        This uses greedy decoding.
+
+        Args:
+            head_logits: Head scores of shape N x L x L.
+            deprel_logits: Label scores of shape N x L x L x C.
+            mask: Attention mask of shape N x L.
+
+        Returns:
+            Predicted heads and deprels.
+        """
+        # FIXME indices.
+        pred_head = head_logits.argmax(dim=-1)
+        pred_head.masked_fill_(~mask, special.PAD_IDX)
+        batch_size = deprel_logits.size(0)
+        length = deprel_logits.size(1)
+        num_deprel = deprel_logits.size(3)
+        pred_head_expanded = (
+            pred_head.unsqueeze(-1)
+            .unsqueeze(-1)
+            .expand(batch_size, length, 1, num_deprel)
+        )
+        selected_deprel_logits = torch.gather(
+            deprel_logits, dim=2, index=pred_head_expanded
+        )
+        # FIXME indices.
+ pred_deprel = selected_deprel_logits.squeeze(2).argmax(dim=-1) + pred_head = self._shift_head(pred_head) + pred_deprel.masked_fill_(~mask, special.PAD_IDX) + return pred_head, pred_deprel diff --git a/udtube/special.py b/udtube/special.py index fdcbd21..0c629a9 100644 --- a/udtube/special.py +++ b/udtube/special.py @@ -6,3 +6,5 @@ PAD_IDX = 0 UNK_IDX = 1 + +OFFSET = len(SPECIAL)
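As a sanity check on the einsum-based biaffine scoring in the diff above, the following self-contained sketch (plain PyTorch with arbitrary illustrative sizes; it does not import UDTube) shows how appending a constant 1 feature lets a single pair of tensor contractions subsume the bilinear term, the linear terms, and the bias:

```python
import torch

batch, length, size, out = 2, 5, 8, 3
# Random tensors standing in for the arc/deprel MLP outputs.
head = torch.randn(batch, length, size)
dep = torch.randn(batch, length, size)
# Weight with the +1 augmentation on both input dimensions, as in
# BiaffineAttention above.
weight = torch.randn(out, size + 1, size + 1)
# Append a constant 1 feature to each representation.
head_aug = torch.cat((head, torch.ones_like(head[..., :1])), dim=-1)
dep_aug = torch.cat((dep, torch.ones_like(dep[..., :1])), dim=-1)
# Contract dependents with the weight, then with heads.
dep_weight = torch.einsum("bld,odh->blho", dep_aug, weight)
scores = torch.einsum("bsh,bdho->bdso", head_aug, dep_weight)
# One score per (dependent position, head position, output class).
assert scores.shape == (batch, length, length, out)
```

Each entry `scores[b, i, j, o]` equals `dep_aug[b, i] @ weight[o] @ head_aug[b, j]`, i.e. a full biaffine score for dependent i with candidate head j, which is what the shape asserts in `BiaffineParser.forward` rely on.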