From 13019fd80bb1e32f702ac913a3e82aad156581d6 Mon Sep 17 00:00:00 2001 From: Ricardo Rei Date: Fri, 17 Apr 2020 15:44:40 +0100 Subject: [PATCH] open source version --- .gitignore | 131 ++++++ README.md | 73 ++++ caption/__init__.py | 0 caption/__main__.py | 86 ++++ caption/data_loader.py | 119 ++++++ caption/datasets/__init__.py | 2 + caption/datasets/language_modeling.py | 27 ++ caption/datasets/lazy_dataset.py | 33 ++ caption/datasets/sequence_tagging.py | 64 +++ caption/models/__init__.py | 30 ++ caption/models/activations.py | 100 +++++ caption/models/caption_base_model.py | 388 ++++++++++++++++++ caption/models/encoders/__init__.py | 22 + caption/models/encoders/bert.py | 86 ++++ caption/models/encoders/encoder_base.py | 88 ++++ caption/models/encoders/hf_roberta.py | 84 ++++ caption/models/encoders/roberta.py | 103 +++++ caption/models/encoders/xlm_roberta.py | 116 ++++++ caption/models/language_models/__init__.py | 1 + caption/models/language_models/lm_base.py | 191 +++++++++ caption/models/language_models/masked_lm.py | 67 +++ caption/models/metrics.py | 102 +++++ caption/models/scalar_mix.py | 131 ++++++ caption/models/taggers/__init__.py | 4 + caption/models/taggers/tagger_base.py | 283 +++++++++++++ caption/models/taggers/transformer_tagger.py | 217 ++++++++++ caption/models/utils.py | 78 ++++ caption/optimizers/__init__.py | 27 ++ caption/optimizers/adam.py | 81 ++++ caption/optimizers/adamax.py | 66 +++ caption/optimizers/adamw.py | 80 ++++ caption/optimizers/optim_args.py | 39 ++ caption/optimizers/radam.py | 188 +++++++++ caption/schedulers/__init__.py | 29 ++ caption/schedulers/constant_lr.py | 40 ++ caption/schedulers/linear_warmup.py | 77 ++++ caption/schedulers/scheduler_args.py | 34 ++ caption/schedulers/warmup_constant.py | 57 +++ caption/testing.py | 86 ++++ caption/tokenizers/__init__.py | 11 + caption/tokenizers/bert_tokenizer.py | 58 +++ caption/tokenizers/hf_roberta_tokenizer.py | 69 ++++ caption/tokenizers/roberta_tokenizer.py | 64 +++ caption/tokenizers/tokenizer_base.py | 131 ++++++ caption/training.py | 320 +++++++++++++++ caption/utils.py | 62 +++ requirements.txt | 13 + setup.py | 18 + tests/__init__.py | 0 tests/lightning_models/__init__.py | 0 .../lightning_models/test_taggers/__init__.py | 0 .../test_taggers/test_transformer_tagger.py | 125 ++++++ tests/test_metrics.py | 62 +++ tests/unit/__init__.py | 0 tests/unit/test_encoders/__init__.py | 0 tests/unit/test_encoders/test_bert_encoder.py | 60 +++ tests/unit/test_optimizers/__init__.py | 0 tests/unit/test_optimizers/test_adam.py | 22 + tests/unit/test_optimizers/test_adamax.py | 22 + tests/unit/test_optimizers/test_adamw.py | 22 + tests/unit/test_optimizers/test_radam.py | 22 + tests/unit/test_schedulers/__init__.py | 0 .../test_constant_scheduler.py | 23 ++ .../test_linear_warmup_scheduler.py | 23 ++ .../test_warmup_constant_scheduler.py | 23 ++ tests/unit/test_tokenizers/__init__.py | 0 .../test_tokenizers/test_bert_tokenizer.py | 114 +++++ .../test_tokenizers/test_roberta_tokenizer.py | 114 +++++ 68 files changed, 4808 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 caption/__init__.py create mode 100644 caption/__main__.py create mode 100644 caption/data_loader.py create mode 100644 caption/datasets/__init__.py create mode 100644 caption/datasets/language_modeling.py create mode 100644 caption/datasets/lazy_dataset.py create mode 100644 caption/datasets/sequence_tagging.py create mode 100644 caption/models/__init__.py create mode 100644 
caption/models/activations.py create mode 100644 caption/models/caption_base_model.py create mode 100644 caption/models/encoders/__init__.py create mode 100644 caption/models/encoders/bert.py create mode 100644 caption/models/encoders/encoder_base.py create mode 100644 caption/models/encoders/hf_roberta.py create mode 100644 caption/models/encoders/roberta.py create mode 100644 caption/models/encoders/xlm_roberta.py create mode 100644 caption/models/language_models/__init__.py create mode 100644 caption/models/language_models/lm_base.py create mode 100644 caption/models/language_models/masked_lm.py create mode 100644 caption/models/metrics.py create mode 100644 caption/models/scalar_mix.py create mode 100644 caption/models/taggers/__init__.py create mode 100644 caption/models/taggers/tagger_base.py create mode 100644 caption/models/taggers/transformer_tagger.py create mode 100644 caption/models/utils.py create mode 100644 caption/optimizers/__init__.py create mode 100644 caption/optimizers/adam.py create mode 100644 caption/optimizers/adamax.py create mode 100644 caption/optimizers/adamw.py create mode 100644 caption/optimizers/optim_args.py create mode 100644 caption/optimizers/radam.py create mode 100644 caption/schedulers/__init__.py create mode 100644 caption/schedulers/constant_lr.py create mode 100644 caption/schedulers/linear_warmup.py create mode 100644 caption/schedulers/scheduler_args.py create mode 100644 caption/schedulers/warmup_constant.py create mode 100644 caption/testing.py create mode 100644 caption/tokenizers/__init__.py create mode 100644 caption/tokenizers/bert_tokenizer.py create mode 100644 caption/tokenizers/hf_roberta_tokenizer.py create mode 100644 caption/tokenizers/roberta_tokenizer.py create mode 100644 caption/tokenizers/tokenizer_base.py create mode 100644 caption/training.py create mode 100644 caption/utils.py create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/lightning_models/__init__.py create mode 100644 tests/lightning_models/test_taggers/__init__.py create mode 100644 tests/lightning_models/test_taggers/test_transformer_tagger.py create mode 100644 tests/test_metrics.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_encoders/__init__.py create mode 100644 tests/unit/test_encoders/test_bert_encoder.py create mode 100644 tests/unit/test_optimizers/__init__.py create mode 100644 tests/unit/test_optimizers/test_adam.py create mode 100644 tests/unit/test_optimizers/test_adamax.py create mode 100644 tests/unit/test_optimizers/test_adamw.py create mode 100644 tests/unit/test_optimizers/test_radam.py create mode 100644 tests/unit/test_schedulers/__init__.py create mode 100644 tests/unit/test_schedulers/test_constant_scheduler.py create mode 100644 tests/unit/test_schedulers/test_linear_warmup_scheduler.py create mode 100644 tests/unit/test_schedulers/test_warmup_constant_scheduler.py create mode 100644 tests/unit/test_tokenizers/__init__.py create mode 100644 tests/unit/test_tokenizers/test_bert_tokenizer.py create mode 100644 tests/unit/test_tokenizers/test_roberta_tokenizer.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1573498 --- /dev/null +++ b/.gitignore @@ -0,0 +1,131 @@ +*.pkl +*DS_Store +data/ +experiments/ +configs/ +._data + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ 
+lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..6c3b80c --- /dev/null +++ b/README.md @@ -0,0 +1,73 @@ +# Capitalisation And PuncTuatION (CAPTION) +> PT2020 Transcription project. + +In this repository, we explore different strategies for automatic transcription enrichment for ASR data which includes tasks such as automatic capitalization (truecasing) and punctuation recovery. + +## Model architecture: + +![base_model](https://i.ibb.co/sm3P2Bq/Screenshot-2020-04-14-at-16-19-10.png) + +### Available Encoders: +- [BERT](https://arxiv.org/abs/1810.04805) +- [RoBERTa](https://arxiv.org/abs/1907.11692) +- [XLM-RoBERTa](https://arxiv.org/pdf/1911.02116.pdf) + +## Requirements: + +This project uses Python >3.6 + +Create a virtual env with (outside the project folder): + +```bash +virtualenv -p python3.6 caption-env +``` + +Activate venv: +```bash +source caption-env/bin/activate +``` + +Finally, run: +```bash +python setup.py install +``` + +If you wish to make changes into the code run: +```bash +pip install -r requirements.txt +pip install -e . +``` + +## Getting Started: + +### Train: +```bash +python caption train -f {your_config_file}.yaml +``` + +### Testing: +```bash +python caption test \ + --checkpoint=some/path/to/your/checkpoint.ckpt \ + --test_csv=path/to/your/testset.csv +``` + +### Tensorboard: + +Launch tensorboard with: +```bash +tensorboard --logdir="experiments/lightning_logs/" +``` + +If you are running experiments in a remote server you can forward your localhost to the server localhost.. + +### How to run the tests: +In order to run the toolkit tests you must run the following command: + +```bash +cd tests +python -m unittest +``` + +### Code Style: +To make sure all the code follows the same style we use [Black](https://github.com/psf/black). 
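+
+### Example training config:
+
+Training options are read from the YAML file passed with `-f`. The exact keys mirror the command-line options defined by the model, optimizer, scheduler and trainer argument parsers, so the snippet below is only an illustrative sketch and not a verified configuration. In particular, the optimizer and scheduler names are assumptions; check `caption/optimizers/__init__.py` and `caption/schedulers/__init__.py` for the registered names.
+
+```yaml
+# Components to build (names must match the ones registered in the code).
+model: TransformerTagger        # or MaskedLanguageModel
+optimizer: Adam                 # assumed name, see caption/optimizers
+scheduler: constant             # assumed name, see caption/schedulers
+
+# Encoder options.
+encoder_model: RoBERTa          # BERT, RoBERTa or XLM-RoBERTa
+pretrained_model: roberta.base
+nr_frozen_epochs: 1             # keep the encoder frozen during the first epoch
+
+# Training / monitoring (key names taken from the hyperparameters used in the code).
+seed: 3
+batch_size: 8
+monitor: macro_fscore           # assumed metric name for the tagger
+metric_mode: max
+
+# Data (the sequence tagging loader expects CSV files with text and tags columns).
+data_type: csv
+train_path: data/dummy_train.csv
+dev_path: data/dummy_test.csv
+test_path: data/dummy_test.csv
+loader_workers: 4
+```
+
+Pass the file to the trainer with `python caption train -f config.yaml`.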
diff --git a/caption/__init__.py b/caption/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/caption/__main__.py b/caption/__main__.py new file mode 100644 index 0000000..d016dbc --- /dev/null +++ b/caption/__main__.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +import logging +from data_loader import add_data_args + +from models import add_model_args, build_model +from optimizers import add_optimizer_args +from schedulers import add_scheduler_args +from test_tube import HyperOptArgumentParser +from testing import run_testing, setup_testing +from torchnlp.random import set_seed +from training import add_trainer_specific_args, setup_training +from utils import get_main_args_from_yaml, load_yaml_args + +log = logging.getLogger("Shell") +logging.basicConfig(level=logging.INFO) + + +def run_training_pipeline(parser): + parser.add_argument( + "-f", "--config", default=False, type=str, help="Path to a YAML config file." + ) + parser.add_argument( + "--optimizer", + default=False, + type=str, + help="Optimizer to be used during training.", + ) + parser.add_argument( + "--scheduler", + default=False, + type=str, + help="LR scheduler to be used during training.", + ) + parser.add_argument( + "--model", + default=False, + type=str, + help="The estimator architecture we we wish to use.", + ) + args, _ = parser.parse_known_args() + + if not args.optimizer and not args.scheduler and not args.model: + optimizer, scheduler, model = get_main_args_from_yaml(args) + else: + optimizer = args.optimizer + scheduler = args.scheduler + model = args.model + + parser = add_optimizer_args(parser, optimizer) + parser = add_scheduler_args(parser, scheduler) + parser = add_model_args(parser, model) + parser = add_trainer_specific_args(parser) + hparams = load_yaml_args(parser=parser, log=log) + + set_seed(hparams.seed) + model = build_model(hparams) + trainer = setup_training(hparams) + + if hparams.load_weights: + model.load_weights(hparams.load_weights) + + log.info(f"{model.__class__.__name__} train starting:") + trainer.fit(model) + + +def run_testing_pipeline(parser): + parser = add_data_args(parser) + parser.add_argument( + "--checkpoint", default=None, help="Checkpoint file path.", + ) + hparams = parser.parse_args() + run_testing(hparams) + + +if __name__ == "__main__": + parser = HyperOptArgumentParser( + strategy="random_search", description="CAPTION project", add_help=True + ) + parser.add_argument( + "pipeline", choices=["train", "test"], help="train a model or test.", + ) + args, _ = parser.parse_known_args() + if args.pipeline == "test": + run_testing_pipeline(parser) + else: + run_training_pipeline(parser) diff --git a/caption/data_loader.py b/caption/data_loader.py new file mode 100644 index 0000000..e5fad63 --- /dev/null +++ b/caption/data_loader.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +import pandas as pd + +from test_tube import HyperOptArgumentParser +from torchnlp.datasets.dataset import Dataset + + +def add_data_args(parser: HyperOptArgumentParser) -> HyperOptArgumentParser: + """ + Functions that parses dataset specific arguments/hyperparameters. + :param hparams: HyperOptArgumentParser obj. 
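+    Note: the CSV loaders in this module expect each file to contain source, target
+    and tags columns (see load_from_csv below).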
+ + Returns: + - updated parser + """ + parser.add_argument( + "--data_type", + default="csv", + type=str, + help="The type of the file containing the training/dev/test data.", + choices=["csv"], + ) + parser.add_argument( + "--train_path", + default="data/dummy_train.csv", + type=str, + help="Path to the file containing the train data.", + ) + parser.add_argument( + "--dev_path", + default="data/dummy_test.csv", + type=str, + help="Path to the file containing the dev data.", + ) + parser.add_argument( + "--test_path", + default="data/dummy_test.csv", + type=str, + help="Path to the file containing the test data.", + ) + parser.add_argument( + "--loader_workers", + default=0, + type=int, + help="How many subprocesses to use for data loading. 0 means that \ + the data will be loaded in the main process.", + ) + return parser + + +def file_to_list(filename: str) -> list: + """ Reads a file and returns a list with the cotent of each line. """ + with open(filename, "r") as filehandler: + content = [line.strip() for line in filehandler.readlines()] + return content + + +def collate_lists(source: list, target: list, tags: list) -> dict: + """ For each line of the source, target and tags tags creates a dictionary. """ + collated_dataset = [] + for i in range(len(source)): + collated_dataset.append( + {"source": str(source[i]), "target": str(target[i]), "tags": str(tags[i]),} + ) + return collated_dataset + + +def load_from_csv(hparams: HyperOptArgumentParser, train=True, val=True, test=True): + """ + This dataset loader is used for loading: + Source text, target text and tags for training, development and testing. + + :param hparams: HyperOptArgumentParser obj containg the path to the data files. + :param train: flag to return the train set. + :param val: flag to return the validation set. + :param test: flag to return the test set. + + Returns: + - Training Dataset, Development Dataset, Testing Dataset + """ + + def load_dataset(path): + df = pd.read_csv(path) + source = list(df.source) + target = list(df.target) + tags = list(df.tags) + assert len(source) == len(target) == len(tags) + return Dataset(collate_lists(source, target, tags)) + + func_out = [] + if train: + func_out.append(load_dataset(hparams.train_path)) + if val: + func_out.append(load_dataset(hparams.dev_path)) + if test: + func_out.append(load_dataset(hparams.test_path)) + return tuple(func_out) + + +def text_recovery_dataset( + hparams: HyperOptArgumentParser, train=True, val=True, test=True +): + """ + This dataset loader is used for loading: + Source text, arget text and tags for training, development and testing. + + This task consists in automatic capitalization and punctuation recovery. + + :param hparams: HyperOptArgumentParser obj containg the path to the data files. + + Returns: + - Training Dataset, Development Dataset, Testing Dataset + """ + if hparams.data_type == "csv": + return load_from_csv(hparams, train, val, test) + else: + raise Exception( + "Invalid configs data_type. Only csv and txt files are supported." 
+ ) diff --git a/caption/datasets/__init__.py b/caption/datasets/__init__.py new file mode 100644 index 0000000..a3fb0c6 --- /dev/null +++ b/caption/datasets/__init__.py @@ -0,0 +1,2 @@ +from .sequence_tagging import sequence_tagging_dataset +from .language_modeling import load_mlm_dataset diff --git a/caption/datasets/language_modeling.py b/caption/datasets/language_modeling.py new file mode 100644 index 0000000..9dcb197 --- /dev/null +++ b/caption/datasets/language_modeling.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +from test_tube import HyperOptArgumentParser + +from .lazy_dataset import LineByLineTextDataset + + +def load_mlm_dataset(hparams: HyperOptArgumentParser, train=True, val=True, test=True): + """ + This dataset loader is used for loading data for language modeling. + + :param hparams: HyperOptArgumentParser obj containg the path to the data files. + :param train: flag to return the train set. + :param val: flag to return the validation set. + :param test: flag to return the test set. + + Returns: + - Training Dataset, Development Dataset, Testing Dataset + """ + func_out = [] + if train: + func_out.append(LineByLineTextDataset(hparams.train_path)) + if val: + func_out.append(LineByLineTextDataset(hparams.dev_path)) + if test: + func_out.append(LineByLineTextDataset(hparams.test_path)) + + return tuple(func_out) diff --git a/caption/datasets/lazy_dataset.py b/caption/datasets/lazy_dataset.py new file mode 100644 index 0000000..a15e4a5 --- /dev/null +++ b/caption/datasets/lazy_dataset.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +import linecache +import os +import torch + + +class LineByLineTextDataset(torch.utils.data.Dataset): + """ + Dataset object that reads a txt file line by line. + + :param path: Path to the txt file containing our sentences. + """ + + def __init__(self, file_path: str): + assert os.path.isfile(file_path) + + with open(file_path, "r") as fp: + self.num_lines = sum(1 for line in fp) - 1 + + self.file_path = file_path + + def __len__(self): + return self.num_lines + 1 + + def __getitem__(self, idx): + if idx > self.num_lines: + raise ValueError( + "Trying to access index {} in a dataset with size={}".format( + idx, self.num_lines + ) + ) + line = linecache.getline(self.file_path, idx + 1) + return {"text": line.strip()} diff --git a/caption/datasets/sequence_tagging.py b/caption/datasets/sequence_tagging.py new file mode 100644 index 0000000..3978556 --- /dev/null +++ b/caption/datasets/sequence_tagging.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +import pandas as pd + +from test_tube import HyperOptArgumentParser + + +def collate_lists(text: list, tags: list) -> dict: + """ For each line of the text and tags creates a dictionary. """ + collated_dataset = [] + for i in range(len(text)): + collated_dataset.append( + {"text": str(text[i]), "tags": str(tags[i]),} + ) + return collated_dataset + + +def load_from_csv(hparams: HyperOptArgumentParser, train=True, val=True, test=True): + """ + Dataset loader function used for loading: + text and tags for training, development and testing. + + :param hparams: HyperOptArgumentParser obj containg the path to the data files. + :param train: flag to return the train set. + :param val: flag to return the validation set. + :param test: flag to return the test set. 
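+    Note: each CSV file is expected to contain a text and a tags column.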
+ + Returns: + - Training Dataset, Development Dataset, Testing Dataset + """ + + def load_dataset(path): + df = pd.read_csv(path) + text = list(df.text) + tags = list(df.tags) + assert len(text) == len(tags) + return collate_lists(text, tags) + + func_out = [] + if train: + func_out.append(load_dataset(hparams.train_path)) + if val: + func_out.append(load_dataset(hparams.dev_path)) + if test: + func_out.append(load_dataset(hparams.test_path)) + return tuple(func_out) + + +def sequence_tagging_dataset( + hparams: HyperOptArgumentParser, train=True, val=True, test=True +): + """ + Function that loads a tagging dataset for automatic capitalization and punctuation recovery. + + :param hparams: HyperOptArgumentParser obj containing the path to the data files. + + Returns: + - Training Dataset, Development Dataset, Testing Dataset + """ + if hparams.data_type == "csv": + return load_from_csv(hparams, train, val, test) + else: + raise Exception( + "Invalid configs data_type. Only csv and txt files are supported." + ) diff --git a/caption/models/__init__.py b/caption/models/__init__.py new file mode 100644 index 0000000..f167452 --- /dev/null +++ b/caption/models/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +import logging + +import pytorch_lightning as ptl +import pandas as pd +import os + +from .taggers import TransformerTagger +from .language_models import MaskedLanguageModel + +str2model = { + "TransformerTagger": TransformerTagger, + "MaskedLanguageModel": MaskedLanguageModel, +} + + +def build_model(hparams) -> ptl.LightningModule: + """ + Function that builds an estimator model from the HyperOptArgumentParser + :param hparams: HyperOptArgumentParser + """ + return str2model[hparams.model](hparams) + + +def add_model_args(parser, model: str): + try: + return str2model[model].add_model_specific_args(parser) + except KeyError: + raise Exception(f"{model} is not a valid model type!") diff --git a/caption/models/activations.py b/caption/models/activations.py new file mode 100644 index 0000000..960972a --- /dev/null +++ b/caption/models/activations.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +import math + +import torch +from torch import nn + + +def build_activation(activation: str): + """ Builder function that returns a nn.module activation function. + + :param activation: string defining the name of the activation function. + + Activations available: + GELU, Swish + every native pytorch activation function. + """ + if hasattr(nn, activation): + return getattr(nn, activation)() + elif activation == "Swish": + return Swish() + elif activation == "GELU": + return GELU() + else: + raise Exception("{} invalid activation function.".format(activation)) + + +def swish(input): + """ + Applies Swish element-wise: A self-gated activation function + swish(x) = x * sigmoid(x) + """ + return input * torch.sigmoid(input) + + +class Swish(nn.Module): + """ + Applies the Swish function element-wise: + + Swish(x) = x * sigmoid(x) + + Shape: + - Input: (N, *) where * means, any number of additional + dimensions + - Output: (N, *), same shape as the input + + References: + - Related paper: + https://arxiv.org/pdf/1710.05941v1.pdf + """ + + def __init__(self): + """ + Init method. + """ + super().__init__() # init the base class + + def forward(self, input): + """ + Forward pass of the function.
+ """ + return swish(input) + + +def gelu(x): + """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). + Also see https://arxiv.org/abs/1606.08415 + """ + return ( + 0.5 + * x + * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +class GELU(nn.Module): + """ + Applies the GELU function element-wise: + + GELU(x) = 0.5*(1 + tanh(√2/π) * (x + 0.044715 * x^3)) + + Shape: + - Input: (N, *) where * means, any number of additional + dimensions + - Output: (N, *), same shape as the input + + References: + - Related paper: + https://arxiv.org/pdf/1606.08415.pdf + """ + + def __init__(self): + """ + Init method. + """ + super().__init__() # init the base class + + def forward(self, input): + """ + Forward pass of the function. + """ + return gelu(input) diff --git a/caption/models/caption_base_model.py b/caption/models/caption_base_model.py new file mode 100644 index 0000000..e53fd7f --- /dev/null +++ b/caption/models/caption_base_model.py @@ -0,0 +1,388 @@ +# -*- coding: utf-8 -*- +r""" +CAPTION Model Base +============== + Abstract base class used to build new modules inside CAPTION. +""" +import json +import logging +import math + +import numpy as np +import torch +from torch.utils.data import DataLoader, RandomSampler, Subset +from tqdm import tqdm + +import pytorch_lightning as ptl +from caption.models.encoders import Encoder +from caption.optimizers import build_optimizer +from caption.schedulers import build_scheduler +from caption.tokenizers import TextEncoderBase +from test_tube import HyperOptArgumentParser + +torch.set_printoptions(precision=6) +log = logging.getLogger("Shell") + + +class CaptionModelBase(ptl.LightningModule): + """ + Caption Modules extend PyTorch Lightning with a common structure and interface + that will be shared across all modules (e.g. estimators, word_taggers, classifiers, etc..) + in this project . + + :param hparams: HyperOptArgumentParser containing the hyperparameters. + """ + + def __init__(self, hparams: HyperOptArgumentParser,) -> None: + super(CaptionModelBase, self).__init__() + self.hparams = hparams + self._encoder = self._build_encoder(hparams) + + # Model initialization + self._build_model() + + # Loss criterion initialization. + self._build_loss() + + # The encoder always starts in a frozen state. + if hparams.nr_frozen_epochs > 0: + self._frozen = True + self.freeze_encoder() + else: + self._frozen = False + + self.nr_frozen_epochs = hparams.nr_frozen_epochs + + # training helpers. + self._pbar = None # used during training to produce a loading bar. + + # used during hyperparameter search only. + self._best = {"val_loss": math.inf} + self._best[self.hparams.monitor] = ( + -math.inf if self.hparams.metric_mode == "max" else math.inf + ) + + def _build_loss(self): + """ Initializes the loss function/s. """ + raise NotImplementedError + + def _build_model(self) -> ptl.LightningModule: + """ + Initializes the estimator architecture. + """ + raise NotImplementedError + + def _build_encoder( + self, hparams: HyperOptArgumentParser + ) -> (Encoder, TextEncoderBase): + """ + Initializes the encoder. + """ + raise NotImplementedError + + def _retrieve_dataset(self, data_hparams, train=True, val=True, test=True): + """ Retrieves task specific dataset """ + raise NotImplementedError + + @property + def encoder(self): + """ Model encoding layer. """ + return self._encoder + + def freeze_encoder(self) -> None: + """ Freezes the encoder layer. 
""" + self.encoder.freeze() + + def unfreeze_encoder(self) -> None: + """ un-freezes the encoder layer. """ + if self._frozen: + log.info(f"\n-- Encoder model fine-tuning") + self.encoder.unfreeze() + self._frozen = False + + def predict(self, sample: dict) -> dict: + """ Function that runs a model prediction, + :param sample: dictionary with expected model sequences. + You can also pass a list of dictionaries to predict an entire batch. + + Return: Dictionary with model outputs + """ + raise NotImplementedError + + def forward(self, *args, **kwargs) -> dict: + """ + PyTorch Forward. + Return: Dictionary with model outputs to be passed to the loss function. + """ + raise NotImplementedError + + def _compute_loss(self, model_out: dict, targets: dict) -> torch.tensor: + """ + Computes Loss value according to a loss function. + :param model_out: model specific output. + :param targets: Target score values [batch_size] + """ + raise NotImplementedError + + def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict): + """ + Function that prepares a sample to input the model. + :param sample: list of dictionaries. + + Returns: + - dictionary with the expected model inputs. + - dictionary with the expected target values (e.g. HTER score). + """ + raise NotImplementedError + + def configure_optimizers(self): + """ Function for setting up the optimizers and the schedulers to be used during training. + + Returns: + - List with as many optimizers as we need + - List with the respective schedulers. + """ + optimizer = build_optimizer(self.parameters(), self.hparams) + scheduler = build_scheduler(optimizer, self.hparams) + return [optimizer], [scheduler] + + def _compute_metrics(self, outputs: list) -> dict: + """ + Private function that computes metrics of interest based on the list of outputs + you defined in validation_step. + """ + raise NotImplementedError + + def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: + """ Runs one training step. This usually consists in the forward function followed + by the loss function. + :param batch: The output of your dataloader. + :param batch_nb: Integer displaying which batch this is + + Returns: + - dictionary containing the loss and the metrics to be added to the lightning logger. + """ + batch_input, batch_target = batch + batch_prediction = self.forward(**batch_input) + loss_value = self._compute_loss(batch_prediction, batch_target) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp or self.trainer.use_ddp2: + loss_value = loss_value.unsqueeze(0) + + return {"loss": loss_value} + + def validation_step(self, batch: tuple, batch_nb: int, dataloader_idx: int) -> dict: + """ Similar to the training step but with the model in eval mode. + + Returns: + - dictionary passed to the validation_end function. 
+ """ + batch_input, batch_target = batch + batch_prediction = self.forward(**batch_input) + loss_value = self._compute_loss(batch_prediction, batch_target) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp or self.trainer.use_ddp2: + loss_value = loss_value.unsqueeze(0) + + return { + "val_loss": loss_value, + "val_prediction": batch_prediction, + "val_target": batch_target, + } + + def validation_epoch_end(self, outputs: list) -> dict: + """ Function that takes as input a list of dictionaries returned by the validation_step + function and measures the model performance accross the entire validation set. + + Returns: + - Dictionary with metrics to be added to the lightning logger. + """ + train_batches, val_batches = outputs + avg_train_loss = torch.stack([x["val_loss"] for x in train_batches]).mean() + avg_val_loss = torch.stack([x["val_loss"] for x in val_batches]).mean() + + train_metrics = self._compute_metrics(train_batches) + metrics = self._compute_metrics(val_batches) + + log.info(f"-- Avg Train loss {avg_train_loss:.4}") + log.info("-- Train metrics:\n{}".format(json.dumps(train_metrics, indent=1))) + + log.info(f"-- Avg Dev loss {avg_val_loss:.4}") + log.info("-- Dev metrics:\n{}".format(json.dumps(metrics, indent=1))) + + # Store internally the best pearson result achieved. + if ( + metrics[self.hparams.monitor] > self._best[self.hparams.monitor] + and self.hparams.metric_mode == "max" + ): + + self._best = { + self.hparams.monitor: metrics[self.hparams.monitor], + "val_loss": avg_val_loss.item(), + } + elif ( + metrics[self.hparams.monitor] < self._best[self.hparams.monitor] + and self.hparams.metric_mode == "min" + ): + + self._best = { + self.hparams.monitor: metrics[self.hparams.monitor], + "val_loss": avg_val_loss.item(), + } + + return { + "log": {**metrics, "val_loss": avg_val_loss, "train_loss": avg_train_loss} + } + + def test_step(self, batch: list, batch_nb: int, *args, **kwargs) -> dict: + """ Redirects to validation step. """ + pass + + def test_epoch_end(self, outputs: list) -> dict: + """ Redirects to validation end. """ + pass + + def prepare_data(self) -> None: + """Data preparation function called before training by Lightning""" + ( + self._train_dataset, + self._val_dataset, + self._test_dataset, + ) = self._retrieve_dataset(self.hparams) + + train_subset = np.random.choice( + a=len(self._train_dataset), + size=int(len(self._train_dataset) * self.hparams.train_val_percent_check), + ) + self._train_subset = Subset(self._train_dataset, train_subset) + + def train_dataloader(self) -> DataLoader: + """ Function that loads the train set. """ + return DataLoader( + dataset=self._train_dataset, + # sampler=RandomSampler(self._train_dataset), + batch_size=self.hparams.batch_size, + collate_fn=self.prepare_sample, + num_workers=self.hparams.loader_workers, + ) + + def val_dataloader(self) -> DataLoader: + """ Function that loads the validation set. """ + return [ + DataLoader( + dataset=self._train_subset, + batch_size=self.hparams.batch_size, + collate_fn=self.prepare_sample, + num_workers=self.hparams.loader_workers, + ), + DataLoader( + dataset=self._val_dataset, + batch_size=self.hparams.batch_size, + collate_fn=self.prepare_sample, + num_workers=self.hparams.loader_workers, + ), + ] + + def test_dataloader(self) -> DataLoader: + """ Function that loads the validation set. 
""" + return DataLoader( + dataset=self._test_dataset, + batch_size=self.hparams.batch_size, + collate_fn=self.prepare_sample, + num_workers=self.hparams.loader_workers, + ) + + def on_epoch_end(self): + """ Pytorch lightning hook """ + if self.current_epoch + 1 >= self.nr_frozen_epochs and self._frozen: + self.unfreeze_encoder() + self._frozen = False + + def on_epoch_start(self): + """ Pytorch lightning hook """ + if self.current_epoch == 0 and not self._frozen: + log.info(f"\n-- Encoder model fine-tuning.") + + if not self.hparams.disable_progress_bar: + nr_batches = math.ceil( + (len(self._train_dataset) / self.hparams.batch_size) + * self.hparams.train_percent_check + ) + self._pbar = tqdm(total=nr_batches, unit="batch") + + def on_batch_start(self, batch): + """ Pytorch lightning hook """ + if not self.hparams.disable_progress_bar: + self._pbar.update(1) + + def on_pre_performance_check(self): + """ Pytorch lightning hook """ + if ( + not self.hparams.disable_progress_bar + and self.hparams.val_check_interval >= 1 + ): + # closes tqdm progress bar before updating the shell + self._pbar.close() if self._pbar else None + + def load_weights(self, checkpoint: str) -> None: + """ Function that loads the weights from a given checkpoint file. + Note: + If the checkpoint model architecture is different then `self`, only + the common parts will be loaded. + + :param checkpoint: Path to the checkpoint containing the weights to be loaded. + """ + log.info(f"loading model weights from {checkpoint}.") + checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage,) + pretrained_dict = checkpoint["state_dict"] + model_dict = self.state_dict() + + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + # 2. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + # 3. load the new state dict + self.load_state_dict(model_dict) + + # ------------------------------------ Arg parsing ------------------------------------------ + @staticmethod + def add_encoder_args(parser: HyperOptArgumentParser) -> HyperOptArgumentParser: + """ + Functions that parses Encoder specific arguments/hyperparameters. + :param hparams: HyperOptArgumentParser obj. + + Returns: + - updated parser + """ + raise NotImplementedError + + @staticmethod + def add_model_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ Parser for Estimator specific arguments/hyperparameters. + :param parser: HyperOptArgumentParser obj + + Returns: + - updated parser + """ + parser.opt_list( + "--nr_frozen_epochs", + default=0, + type=int, + help="Number of epochs we want to keep the encoder model frozen.", + tunable=True, + options=[0, 1, 2, 3, 4, 5], + ) + parser.add_argument( + "--disable_progress_bar", + default=False, + help=( + "By default the estimator class creates a progress bar during" + "training. Using this flag you can desable this behavior." 
+ ), + action="store_true", + ) + return parser diff --git a/caption/models/encoders/__init__.py b/caption/models/encoders/__init__.py new file mode 100644 index 0000000..28a5d23 --- /dev/null +++ b/caption/models/encoders/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +from .bert import BERT +from .roberta import RoBERTa +from .xlm_roberta import XLMRoBERTa +from .encoder_base import Encoder +from .hf_roberta import HuggingFaceRoBERTa + + +str2encoder = { + "BERT": BERT, + "RoBERTa": RoBERTa, + "XLM-RoBERTa": XLMRoBERTa, + "HF-RoBERTa": HuggingFaceRoBERTa, # legacy +} + +__all__ = [ + "Encoder", + "BERT", + "RoBERTa", + "XLMRoBERTa", + "HuggingFaceRoBERTa", # legacy +] diff --git a/caption/models/encoders/bert.py b/caption/models/encoders/bert.py new file mode 100644 index 0000000..1dfa06c --- /dev/null +++ b/caption/models/encoders/bert.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +r""" +Hugging Face BERT implementation. +============== + BERT model from hugging face transformers repo. +""" +import torch +from transformers import BertModel, BertForMaskedLM + +from caption.models.encoders.encoder_base import Encoder +from caption.tokenizers import BERTTextEncoder +from test_tube import HyperOptArgumentParser +from torchnlp.utils import lengths_to_mask + + +class BERT(Encoder): + """ + BERT encoder. + + :param tokenizer: BERT text encoder. + :param hparams: HyperOptArgumentParser obj. + :param lm_head: If true the language model head from the pretrain model is saved. + + Check the available models here: + https://huggingface.co/transformers/pretrained_models.html + """ + + def __init__( + self, + tokenizer: BERTTextEncoder, + hparams: HyperOptArgumentParser, + lm_head: bool = False, + ) -> None: + super().__init__(768 if "base" in hparams.pretrained_model else 1024, tokenizer) + self._n_layers = 13 if "base" in hparams.pretrained_model else 25 + self.padding_idx = self.tokenizer.padding_index + + if not lm_head: + self.model = BertModel.from_pretrained( + hparams.pretrained_model, output_hidden_states=True + ) + else: + mlm_model = BertForMaskedLM.from_pretrained( + hparams.pretrained_model, output_hidden_states=True + ) + self.model = mlm_model.bert + self.lm_head = mlm_model.cls + + @classmethod + def from_pretrained(cls, hparams: HyperOptArgumentParser, lm_head: bool = False): + """ Function that loads a pretrained BERT encoder. + :param hparams: HyperOptArgumentParser obj. + + Returns: + - BERT Encoder model + """ + tokenizer = BERTTextEncoder(model=hparams.pretrained_model) + model = BERT(tokenizer=tokenizer, hparams=hparams, lm_head=lm_head) + return model + + def forward(self, tokens: torch.tensor, lengths: torch.tensor, **kwargs) -> dict: + """ + Encodes a batch of sequences. + :param tokens: Torch tensor with the input sequences [batch_size x seq_len]. + :param lengths: Torch tensor with the length of each sequence [seq_len]. + + Returns: + - 'sentemb': tensor [batch_size x 1024] with the sentence encoding. + - 'wordemb': tensor [batch_size x seq_len x 1024] with the word level embeddings. + - 'mask': torch.Tensor [seq_len x batch_size] + - 'all_layers': List with the word_embeddings returned by each layer. + - 'extra': tuple with the last_hidden_state [batch_size x seq_len x hidden_size], + the pooler_output representing the entire sentence and the word embeddings for + all BERT layers (list of tensors [batch_size x seq_len x hidden_size]) + + """ + mask = lengths_to_mask(lengths, device=tokens.device) + # Run BERT model. 
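+        # With output_hidden_states=True the call below returns
+        # (last_hidden_state, pooler_output, hidden_states); the mask is passed as the
+        # attention_mask. Note that this tuple unpacking assumes the older transformers
+        # return format; recent releases wrap the outputs in a ModelOutput object instead.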
+ last_hidden_states, pooler_output, all_layers = self.model(tokens, mask) + return { + "sentemb": pooler_output, + "wordemb": last_hidden_states, + "all_layers": all_layers, + "mask": mask, + "extra": (last_hidden_states, pooler_output, all_layers), + } diff --git a/caption/models/encoders/encoder_base.py b/caption/models/encoders/encoder_base.py new file mode 100644 index 0000000..5c6b6ca --- /dev/null +++ b/caption/models/encoders/encoder_base.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +r""" +Encoder Model Base +============== + Abstract base class used to build new pretrained Encoder models. +""" +import os + +import torch +import torch.nn as nn + +from caption.tokenizers import TextEncoderBase +from test_tube import HyperOptArgumentParser + + +class Encoder(nn.Module): + """ Base class for an encoder model. + + :param output_units: Number of output features that will be passed to the Estimator. + """ + + def __init__( + self, output_units: int, tokenizer: TextEncoderBase, lm_head: bool = False + ) -> None: + super().__init__() + self.output_units = output_units + self.tokenizer = tokenizer + + @property + def num_layers(self): + """ Number of model layers available. """ + return self._n_layers + + @classmethod + def from_pretrained(cls, hparams: HyperOptArgumentParser, lm_head: bool = False): + """ Function that loads a pretrained encoder and the respective tokenizer. + + Returns: + - Encoder model + """ + raise NotImplementedError + + def prepare_sample( + self, sample: list, trackpos: bool = True + ) -> (torch.tensor, torch.tensor): + """ Receives a list of strings and applies model specific tokenization and vectorization.""" + if not trackpos: + tokens, lengths = self.tokenizer.batch_encode(sample) + return {"tokens": tokens, "lengths": lengths} + + ( + tokens, + lengths, + word_boundaries, + word_lengths, + ) = self.tokenizer.batch_encode_trackpos(sample) + return { + "tokens": tokens, + "lengths": lengths, + "word_boundaries": word_boundaries, + "word_lengths": word_lengths, + } + + def freeze(self) -> None: + """ Frezees the entire encoder network. """ + for param in self.parameters(): + param.requires_grad = False + + def unfreeze(self) -> None: + """ Unfrezees the entire encoder network. """ + for param in self.parameters(): + param.requires_grad = True + + def forward(self, tokens: torch.tensor, lengths: torch.tensor, **kwargs) -> dict: + """ + Encodes a batch of sequences. + + :param tokens: Torch tensor with the input sequences [batch_size x seq_len]. + :param lengths: Torch tensor with the lenght of each sequence [seq_len]. + + Returns: + - 'sentemb': tensor [batch_size x output_units] with the sentence encoding. + - 'wordemb': tensor [batch_size x seq_len x output_units] with the word level embeddings. + - 'mask': input mask. + - 'all_layers': List with the word_embeddings returned by each layer. + - 'extra': model specific outputs. + """ + raise NotImplementedError diff --git a/caption/models/encoders/hf_roberta.py b/caption/models/encoders/hf_roberta.py new file mode 100644 index 0000000..b833ffa --- /dev/null +++ b/caption/models/encoders/hf_roberta.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +r""" +Hugging Face RoBERTa implementation. +============== + RoBERTa model from hugging face transformers repo. 
+""" +import torch +from transformers import RobertaModel, RobertaForMaskedLM + +from caption.models.encoders.encoder_base import Encoder +from caption.tokenizers import HfRoBERTaTextEncoder +from test_tube import HyperOptArgumentParser +from torchnlp.utils import lengths_to_mask + + +class HuggingFaceRoBERTa(Encoder): + """ + Hugging Face RoBERTa encoder. + + :param tokenizer: RoBERTa text encoder. + :param hparams: HyperOptArgumentParser obj. + :param lm_head: If true the language model head from the pretrain model is saved. + + Check the available models here: + https://huggingface.co/transformers/pretrained_models.html + """ + + def __init__( + self, + tokenizer: HfRoBERTaTextEncoder, + hparams: HyperOptArgumentParser, + lm_head: bool = False, + ) -> None: + super().__init__(768 if "base" in hparams.pretrained_model else 1024, tokenizer) + self._n_layers = 13 if "base" in hparams.pretrained_model else 25 + self.padding_idx = self.tokenizer.padding_index + if not lm_head: + self.model = RobertaModel.from_pretrained( + hparams.pretrained_model, output_hidden_states=True + ) + else: + mlm_model = RobertaForMaskedLM.from_pretrained( + hparams.pretrained_model, output_hidden_states=True + ) + self.model = mlm_model.roberta + self.lm_head = mlm_model.lm_head + + @classmethod + def from_pretrained(cls, hparams: HyperOptArgumentParser, lm_head: bool = False): + """ Function that loads a pretrained RoBERTa encoder. + :param hparams: HyperOptArgumentParser obj. + + Returns: + - RoBERTa Encoder model from hugging face + """ + tokenizer = HfRoBERTaTextEncoder(model=hparams.pretrained_model) + model = RoBERTa(tokenizer=tokenizer, hparams=hparams, lm_head=lm_head) + return model + + def forward(self, tokens: torch.tensor, lengths: torch.tensor, **kwargs) -> dict: + """ + Encodes a batch of sequences. + :param tokens: Torch tensor with the input sequences [batch_size x seq_len]. + :param lengths: Torch tensor with the length of each sequence [seq_len]. + + Returns: + - 'sentemb': tensor [batch_size x 1024] with the sentence encoding. + - 'wordemb': tensor [batch_size x seq_len x 1024] with the word level embeddings. + - 'all_layers': List with the word_embeddings returned by each layer. + - 'mask': torch.Tensor [seq_len x batch_size] + - 'extra': tuple with the last_hidden_state [batch_size x seq_len x hidden_size], + the pooler_output representing the entire sentence and the word embeddings for + all XLM-R layers (list of tensors [batch_size x seq_len x hidden_size]) + """ + mask = lengths_to_mask(lengths, device=tokens.device) + # Run RoBERTa model. + last_hidden_states, pooler_output, all_layers = self.model(tokens, mask) + return { + "sentemb": pooler_output, + "wordemb": last_hidden_states, + "all_layers": all_layers, + "mask": mask, + "extra": (last_hidden_states, pooler_output, all_layers), + } diff --git a/caption/models/encoders/roberta.py b/caption/models/encoders/roberta.py new file mode 100644 index 0000000..b247dae --- /dev/null +++ b/caption/models/encoders/roberta.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +r""" +Fairseq original RoBERTa implementation. +============== + Original RoBERTa model. 
+""" +import os + +import torch + +from caption.models.encoders.encoder_base import Encoder +from caption.tokenizers import RoBERTaTextEncoder +from fairseq.models.roberta import RobertaModel +from test_tube import HyperOptArgumentParser +from torchnlp.download import download_file_maybe_extract +from torchnlp.utils import lengths_to_mask + + +ROBERTA_LARGE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz" +ROBERTA_LARGE_MODEL_NAME = "roberta.large/model.pt" + +ROBERTA_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz" +ROBERTA_BASE_MODEL_NAME = "roberta.base/model.pt" + + +class RoBERTa(Encoder): + """ + RoBERTa encoder from Fairseq. + + :param roberta: RoBERTa model to be used. + :param tokenizer: RoBERTa model tokenizer to be used. + :param hparams: HyperOptArgumentParser obj. + :param lm_head: If true the language model head from the pretrain model is saved. + """ + + def __init__( + self, + roberta: RobertaModel, + tokenizer: RoBERTaTextEncoder, + hparams: HyperOptArgumentParser, + lm_head: bool = False, + ) -> None: + super().__init__(768 if "base" in hparams.pretrained_model else 1024, tokenizer) + self._n_layers = 13 if "base" in hparams.pretrained_model else 25 + self.model = roberta + self.lm_head = self.model.model.decoder.lm_head if lm_head else None + + @classmethod + def from_pretrained(cls, hparams: HyperOptArgumentParser, lm_head: bool = False): + if not os.path.exists("pretrained/"): + os.mkdir("pretrained/") + + pretrained_model = hparams.pretrained_model + if pretrained_model == "roberta.base": + download_file_maybe_extract( + ROBERTA_BASE_URL, + directory="pretrained", + check_files=[ROBERTA_BASE_MODEL_NAME], + ) + + elif pretrained_model == "roberta.large": + download_file_maybe_extract( + ROBERTA_LARGE_URL, + directory="pretrained", + check_files=[ROBERTA_LARGE_MODEL_NAME], + ) + else: + raise Exception(f"{pretrained_model} is an invalid RoBERTa model.") + + roberta = RobertaModel.from_pretrained( + "pretrained/" + pretrained_model, checkpoint_file="model.pt" + ) + roberta.eval() + tokenizer = RoBERTaTextEncoder( + roberta.encode, roberta.task.source_dictionary.__dict__["indices"] + ) + return RoBERTa( + roberta=roberta, tokenizer=tokenizer, hparams=hparams, lm_head=lm_head + ) + + def forward(self, tokens: torch.tensor, lengths: torch.tensor, **kwargs) -> dict: + """ + Encodes a batch of sequences. + :param tokens: Torch tensor with the input sequences [batch_size x seq_len]. + :param lengths: Torch tensor with the length of each sequence [seq_len]. + + Returns: + - 'sentemb': tensor [batch_size x 1024] with the sentence encoding. + - 'wordemb': tensor [batch_size x seq_len x 1024] with the word level embeddings. + - 'all_layers': List with the word_embeddings returned by each layer. + - 'mask': torch.Tensor [seq_len x batch_size] + - 'extra': tuple with all XLM-R layers (list of tensors [batch_size x seq_len x hidden_size]) + """ + mask = lengths_to_mask(lengths, device=tokens.device) + # Run RoBERTa model. 
+ all_layers = self.model.extract_features(tokens, return_all_hiddens=True) + return { + "sentemb": all_layers[-1][:, 0, :], + "wordemb": all_layers[-1], + "all_layers": all_layers, + "mask": mask, + "extra": (all_layers), + } diff --git a/caption/models/encoders/xlm_roberta.py b/caption/models/encoders/xlm_roberta.py new file mode 100644 index 0000000..4a3303c --- /dev/null +++ b/caption/models/encoders/xlm_roberta.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +import os + +import torch + +from caption.models.encoders.encoder_base import Encoder +from caption.tokenizers import RoBERTaTextEncoder +from fairseq.models.roberta import XLMRModel +from test_tube import HyperOptArgumentParser +from torchnlp.download import download_file_maybe_extract +from torchnlp.utils import lengths_to_mask + +XLMR_LARGE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz" +XLMR_LARGE_MODEL_NAME = "xlmr.large/model.pt" + +XLMR_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz" +XLMR_BASE_MODEL_NAME = "xlmr.base/model.pt" + +XLMR_LARGE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz" +XLMR_LARGE_V0_MODEL_NAME = "xlmr.large.v0/model.pt" + +XLMR_BASE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz" +XLMR_BASE_V0_MODEL_NAME = "xlmr.base.v0/model.pt" + + +class XLMRoBERTa(Encoder): + """ + XLM-RoBERTa encoder from Fairseq. + + :param xlmr: XLM-R model to be used. + :param tokenizer: XLM-R model tokenizer to be used. + :param hparams: HyperOptArgumentParser obj. + :param lm_head: If true the language model head from the pretrain model is saved. + """ + + def __init__( + self, + xlmr: XLMRModel, + tokenizer: RoBERTaTextEncoder, + hparams: HyperOptArgumentParser, + lm_head: bool = False, + ) -> None: + super().__init__(768 if "base" in hparams.pretrained_model else 1024, tokenizer) + self._n_layers = 13 if "base" in hparams.pretrained_model else 25 + self.xlmr = xlmr + self.lm_head = self.xlmr.model.decoder.lm_head if lm_head else None + + @classmethod + def from_pretrained(cls, hparams: HyperOptArgumentParser, lm_head: bool = False): + if not os.path.exists("pretrained/"): + os.mkdir("pretrained/") + + pretrained_model = hparams.pretrained_model + if pretrained_model == "xlmr.base": + download_file_maybe_extract( + XLMR_BASE_URL, + directory="pretrained", + check_files=[XLMR_BASE_MODEL_NAME], + ) + + elif pretrained_model == "xlmr.large": + download_file_maybe_extract( + XLMR_LARGE_URL, + directory="pretrained", + check_files=[XLMR_LARGE_MODEL_NAME], + ) + elif pretrained_model == "xlmr.base.v0": + download_file_maybe_extract( + XLMR_BASE_V0_URL, + directory="pretrained", + check_files=[XLMR_BASE_V0_MODEL_NAME], + ) + + elif pretrained_model == "xlmr.large.v0": + download_file_maybe_extract( + XLMR_LARGE_V0_URL, + directory="pretrained", + check_files=[XLMR_LARGE_V0_MODEL_NAME], + ) + else: + raise Exception(f"{pretrained_model} is an invalid XLM-R model.") + + xlmr = XLMRModel.from_pretrained( + "pretrained/" + pretrained_model, checkpoint_file="model.pt" + ) + xlmr.eval() + tokenizer = RoBERTaTextEncoder( + xlmr.encode, xlmr.task.source_dictionary.__dict__["indices"] + ) + return XLMRoBERTa( + xlmr=xlmr, tokenizer=tokenizer, hparams=hparams, lm_head=lm_head + ) + + def forward(self, tokens: torch.tensor, lengths: torch.tensor, **kwargs) -> dict: + """ + Encodes a batch of sequences. + :param tokens: Torch tensor with the input sequences [batch_size x seq_len]. 
+ :param lengths: Torch tensor with the length of each sequence [seq_len]. + + Returns: + - 'sentemb': tensor [batch_size x 1024] with the sentence encoding. + - 'wordemb': tensor [batch_size x seq_len x 1024] with the word level embeddings. + - 'all_layers': List with the word_embeddings returned by each layer. + - 'mask': torch.Tensor [seq_len x batch_size] + - 'extra': tuple with all XLM-R layers (list of tensors [batch_size x seq_len x hidden_size]) + """ + mask = lengths_to_mask(lengths, device=tokens.device) + # Run XLM-R model. + all_layers = self.xlmr.extract_features(tokens, return_all_hiddens=True) + return { + "sentemb": all_layers[-1][:, 0, :], + "wordemb": all_layers[-1], + "all_layers": all_layers, + "mask": mask, + "extra": (all_layers), + } diff --git a/caption/models/language_models/__init__.py b/caption/models/language_models/__init__.py new file mode 100644 index 0000000..4e4b144 --- /dev/null +++ b/caption/models/language_models/__init__.py @@ -0,0 +1 @@ +from .masked_lm import MaskedLanguageModel diff --git a/caption/models/language_models/lm_base.py b/caption/models/language_models/lm_base.py new file mode 100644 index 0000000..eb9bd3d --- /dev/null +++ b/caption/models/language_models/lm_base.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- +r""" +Language Model Base +============== + Abstract base class used to build new language models + inside CAPTION. +""" +import json +import logging + +import torch +import torch.nn as nn + +from caption.datasets import load_mlm_dataset +from caption.models.caption_base_model import CaptionModelBase +from caption.models.encoders import Encoder, str2encoder +from test_tube import HyperOptArgumentParser + +torch.set_printoptions(precision=6) +log = logging.getLogger("Shell") + + +class LanguageModel(CaptionModelBase): + """ + Language Model base class used to fine-tune pretrained models such as RoBERTa. + + :param hparams: HyperOptArgumentParser containing the hyperparameters. + """ + + def __init__(self, hparams: HyperOptArgumentParser,) -> None: + super().__init__(hparams) + + def _build_loss(self): + """ Initializes the loss function/s. """ + if self.hparams.loss == "cross_entropy": + self.loss = nn.CrossEntropyLoss() + else: + raise Exception(f"{loss_func} is not a valid loss option.") + + def _retrieve_dataset(self, hparams, train=True, val=True, test=True): + """ Retrieves task specific dataset """ + return load_mlm_dataset(hparams, train, val, test) + + def _build_encoder(self, hparams: HyperOptArgumentParser) -> Encoder: + """ + Initializes the encoder. + """ + return str2encoder[self.hparams.encoder_model].from_pretrained( + hparams, lm_head=True + ) + + def predict(self, sample: dict) -> dict: + """ Function that runs a model prediction, + :param sample: dictionary with 'src', 'mt' and 'ref' + or a list containing several dictionaries with those keys + Return: Dictionary with model outputs + """ + raise NotImplementedError + + def _compute_loss(self, model_out: dict, targets: dict) -> torch.tensor: + """ + Computes Loss value according to a loss function. + :param model_out: model specific output. 
Must contain a key 'score' with + a tensor [batch_size x 1] with model predictions + :param targets: Target score values [batch_size] + """ + word_logits = model_out["scores"].view(-1, model_out["scores"].size(-1)) + masked_lm_labels = targets["lm_labels"] + return self.loss(word_logits, masked_lm_labels.view(-1)) + + def validation_step(self, batch: tuple, batch_nb: int, dataloader_idx: int) -> dict: + """ Overwrite validation step and to return only the loss.""" + batch_input, batch_target = batch + batch_prediction = self.forward(**batch_input) + loss_value = self._compute_loss(batch_prediction, batch_target) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp or self.trainer.use_ddp2: + loss_value = loss_value.unsqueeze(0) + + return {"val_loss": loss_value} + + def validation_epoch_end(self, outputs: list) -> dict: + """ Overwrite validation end to compute perplexity and skip _compute_metrics.""" + + def perplexity(loss_value, nr_batches): + return {"perplexity": torch.exp(loss_value / nr_batches).item()} + + train_batches, val_batches = outputs + train_loss = torch.stack([x["val_loss"] for x in train_batches]).mean() + val_loss = torch.stack([x["val_loss"] for x in val_batches]).mean() + log.info(f"-- Avg Train loss {train_loss:.4}") + log.info( + "-- Train metrics:\n{}".format( + json.dumps(perplexity(train_loss, len(train_batches)), indent=1) + ) + ) + metrics = perplexity(val_loss, len(val_batches)) + log.info(f"-- Avg Dev loss {val_loss:.4}") + log.info("-- Dev metrics:\n{}".format(json.dumps(metrics), indent=1)) + + # Store internally the best pearson result achieved. + if ( + metrics[self.hparams.monitor] < self._best[self.hparams.monitor] + and self.hparams.metric_mode == "min" + ): + self._best = { + self.hparams.monitor: metrics[self.hparams.monitor], + "val_loss": val_loss.item(), + } + + return {"log": {**metrics, "val_loss": val_loss, "train_loss": train_loss}} + + @staticmethod + def add_model_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ Function that adds shared arguments/hyperparameters across all + estimators. + + :param parser: HyperOptArgumentParser obj + + Returns: + - updated parser + """ + parser = super(LanguageModel, LanguageModel).add_model_specific_args(parser) + parser.add_argument( + "--loss", + default="cross_entropy", + type=str, + help="Loss function to be used.", + choices=["cross_entropy"], + ) + # Parameters for the Encoder model + parser.add_argument( + "--encoder_model", + default="RoBERTa", + type=str, + help="Encoder model to be used.", + choices=["BERT", "XLM-RoBERTa", "RoBERTa"], + ) + parser.add_argument( + "--pretrained_model", + default="roberta-base", + type=str, + help=( + "Encoder pretrained model to be used. " + "(e.g. roberta-base or roberta-large)" + ), + ) + parser.add_argument( + "--mlm_probability", + default=0.15, + type=float, + help="Ratio of tokens to mask for masked language modeling loss.", + ) + # Data Arguments + parser.add_argument( + "--data_type", + default="txt", + type=str, + help="The type of the file containing the training/val/test data.", + choices=["txt"], + ) + parser.add_argument( + "--loader_workers", + default=4, + type=int, + help="How many subprocesses to use for data loading. 
0 means that \ + the data will be loaded in the main process.", + ) + # Data Arguments + parser.add_argument( + "--train_path", + default="data/WMT18/en-de/train.nmt.ref", + type=str, + help="Path to the file containing the train data.", + ) + parser.add_argument( + "--dev_path", + default="data/WMT18/en-de/dev.nmt.ref", + type=str, + help="Path to the file containing the dev data.", + ) + parser.add_argument( + "--test_path", + default="data/WMT18/en-de/dev.nmt.pe", + type=str, + help="Path to the file containing the test data.", + ) + return parser diff --git a/caption/models/language_models/masked_lm.py b/caption/models/language_models/masked_lm.py new file mode 100644 index 0000000..4ed8041 --- /dev/null +++ b/caption/models/language_models/masked_lm.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +r""" +Masked Language Model +============== + Model used to fine-tune encoder models such as RoBERTa and XLM-RoBERTa in a specific + domain. +""" +import torch +import torch.nn as nn + +from caption.models.language_models.lm_base import LanguageModel +from caption.models.utils import mask_tokens +from caption.optimizers import build_optimizer +from caption.schedulers import build_scheduler +from test_tube import HyperOptArgumentParser +from torchnlp.utils import collate_tensors + + +class MaskedLanguageModel(LanguageModel): + """ + Model used to pretrain encoder model such as BERT and XLM-R in with a + Masked Language Modeling objective. + + :param hparams: HyperOptArgumentParser containing the hyperparameters. + """ + + def __init__(self, hparams: HyperOptArgumentParser,) -> None: + super().__init__(hparams) + + def _build_model(self) -> LanguageModel: + """ + The Masked Language model head is already initialized by the encoder. + """ + pass + + def prepare_sample(self, sample: list) -> (dict, dict): + """ + Function that prepares a sample to input the model. + :param sample: list of dictionaries. + + Returns: + - dictionary with the expected model inputs. + - dictionary with the expected target values (e.g. HTER score). + """ + sample = collate_tensors(sample) + sample = self.encoder.prepare_sample(sample["text"], trackpos=False) + tokens, labels = mask_tokens( + sample["tokens"], self.encoder.tokenizer, self.hparams.mlm_probability, + ) + return {"tokens": tokens, "lengths": sample["lengths"]}, {"lm_labels": labels} + + def forward(self, tokens: torch.tensor, lengths: torch.tensor, **kwargs) -> dict: + """ + :param tokens: sequences [batch_size x src_seq_len] + :param lengths: sequence lengths [batch_size] + Return: Dictionary with model outputs to be passed to the loss function. + """ + # When using just one GPU this should not change behavior + # but when splitting batches across GPU the tokens have padding + # from the entire original batch + if self.trainer and self.trainer.use_dp and self.trainer.num_gpus > 1: + tokens = tokens[:, : lengths.max()] + + embeddings = self.encoder(tokens, lengths)["wordemb"] + return { + "scores": self.encoder.lm_head(embeddings), + } diff --git a/caption/models/metrics.py b/caption/models/metrics.py new file mode 100644 index 0000000..004e6a0 --- /dev/null +++ b/caption/models/metrics.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +import numpy as np +from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix + + +def classification_report( + y_pred: np.array, y: np.array, padding: int, labels: dict, ignore: int = -1 +) -> dict: + """ + Function that computes the F Score for all labels, the Macro-average F Score \ + and the slot error rate. 
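To make the masked-LM objective concrete, here is a minimal sketch with toy tensors (illustration only, not library code) of what '_compute_loss' ends up computing once 'mask_tokens' has set the label of every unmasked position to the ignore index:

import torch
import torch.nn as nn

# Toy setup: batch of 2 sequences, length 5, vocabulary of 10 tokens.
logits = torch.randn(2, 5, 10)                       # plays the role of model_out["scores"]
labels = torch.full((2, 5), -100, dtype=torch.long)  # -100 is ignored by CrossEntropyLoss
labels[0, 2] = 7                                     # only the masked positions keep
labels[1, 4] = 3                                     # their original token id as target

loss_fn = nn.CrossEntropyLoss()                      # default ignore_index is -100
loss = loss_fn(logits.view(-1, 10), labels.view(-1))
perplexity = torch.exp(loss)                         # exponential of the average cross-entropy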
+ + :param y: Ground-truth labels. + :param y_pred: Model label predictions. + :param padding: Label padding value. + :param labels: dictionary with the name of each label (e.g. {'BAD': 0, 'OK': 1}) + :param ignore: By setting this value the F Score of this label will not be taken + into consideration when computing the Macro-average F Score. + """ + report = { + "slot_error_rate": slot_error_rate( + y, y_pred, padding, ignore=ignore if ignore >= 0 else None + ) + } + + cm = confusion_matrix(y_pred, y, padding) + index2label = {v: k for k, v in labels.items()} + fscore_avg = [] + tpos, fpos, fneg = [], [], [] + for i in range(len(labels)): + if i == padding: + continue + tp = cm[i][i] + fp = np.sum(cm[i, :]) - tp + fn = np.sum(cm[:, i]) - tp + f_score = fscore(tp, fp, fn) + report["{}_f1_score".format(index2label[i])] = f_score + + if i != ignore: + fscore_avg.append(f_score) + tpos.append(tp) + fpos.append(fp) + fneg.append(fn) + + report["macro_fscore"] = sum(fscore_avg) / len(fscore_avg) + report["micro_fscore"] = fscore(sum(tpos), sum(fpos), sum(fneg)) + return report + + +def precision(tp: int, fp: int, fn: int) -> float: + if tp + fp > 0: + return tp / (tp + fp) + return 0 + + +def recall(tp: int, fp: int, fn: int) -> float: + if tp + fn > 0: + return tp / (tp + fn) + return 0 + + +def fscore(tp: int, fp: int, fn: int) -> float: + p = precision(tp, fp, fn) + r = recall(tp, fp, fn) + if p + r > 0: + return 2 * (p * r) / (p + r) + return 0 + + +def confusion_matrix(y_pred: np.array, y: np.array, padding: int) -> np.array: + """ Function that creates a confusion matrix using the Wikipedia convention for the axis. + :param y_pred: predicted tags. + :param y: the ground-truth tags. + :param padding: padding index to be ignored. + + Returns: + - Confusion matrix for all the labels + padding label.""" + y_pred = np.ma.masked_array(data=y_pred, mask=(y == padding)).filled(padding) + return sklearn_confusion_matrix(y_pred, y) + + +def slot_error_rate( + y_true: np.ndarray, y_pred: np.ndarray, padding=None, ignore=None +) -> np.float64: + """ All classes associated with padding will be ignored in the evaluation. + :param y_true: Ground-truth labels. + :param y_pred: Model label predictions. + :param padding: Label padding value. + :param ignore: Label value corresponding to the majority label. (e.g. B in BIO tags) + Returns: + - np.float64 with slot error rate. + """ + pad_mask = 1 + ref_mask = 1 + if padding is not None: + pad_mask = y_true != padding + if ignore is not None: + ref_mask = y_true != ignore + + slots_ref = np.sum(ref_mask * pad_mask) + errors = np.sum((y_true != y_pred) * pad_mask) + return errors / np.maximum(slots_ref, 1) diff --git a/caption/models/scalar_mix.py b/caption/models/scalar_mix.py new file mode 100644 index 0000000..43ea3c8 --- /dev/null +++ b/caption/models/scalar_mix.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +import torch +from torch.nn import ParameterList, Parameter + + +class ScalarMixWithDropout(torch.nn.Module): + """ + Computes a parameterised scalar mixture of N tensors, 'mixture = gamma * sum(s_k * tensor_k)' + where 's = softmax(w)', with 'w' and 'gamma' scalar parameters. + + If 'do_layer_norm=True' then apply layer normalization to each tensor before weighting. + + If 'dropout > 0', then for each scalar weight, adjust its softmax weight mass to 0 with + the dropout probability (i.e., setting the unnormalized weight to -inf). This effectively + should redistribute dropped probability mass to all other weights. 
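As a quick, hand-worked illustration of the metrics module above (hypothetical toy arrays, not taken from the project's tests):

import numpy as np
from caption.models.metrics import classification_report

# Two-label tagging scheme {"BAD": 0, "OK": 1}; index 2 is used as padding.
y_true = np.array([0, 1, 1, 0, 1, 2])
y_pred = np.array([0, 1, 0, 0, 1, 2])

report = classification_report(y_pred, y_true, padding=2, labels={"BAD": 0, "OK": 1})
# One wrong prediction out of five non-padded slots, so report["slot_error_rate"]
# is 0.2 and both labels (and therefore the macro average) score an F1 of roughly 0.8.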
+ + Original implementation: + - https://github.com/Hyperparticle/udify + """ + + def __init__( + self, + mixture_size: int, + do_layer_norm: bool = False, + initial_scalar_parameters: list = None, + trainable: bool = True, + dropout: float = None, + dropout_value: float = -1e20, + ) -> None: + super(ScalarMixWithDropout, self).__init__() + self.mixture_size = mixture_size + self.do_layer_norm = do_layer_norm + self.dropout = dropout + + if initial_scalar_parameters is None: + initial_scalar_parameters = [0.0] * mixture_size + elif len(initial_scalar_parameters) != mixture_size: + raise Exception( + "Length of initial_scalar_parameters {} differs \ + from mixture_size {}".format( + initial_scalar_parameters, mixture_size + ) + ) + + self.scalar_parameters = ParameterList( + [ + Parameter( + torch.FloatTensor([initial_scalar_parameters[i]]), + requires_grad=trainable, + ) + for i in range(mixture_size) + ] + ) + + self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable) + + if self.dropout: + dropout_mask = torch.zeros(len(self.scalar_parameters)) + dropout_fill = torch.empty(len(self.scalar_parameters)).fill_(dropout_value) + self.register_buffer("dropout_mask", dropout_mask) + self.register_buffer("dropout_fill", dropout_fill) + + def forward( + self, + tensors: list, # pylint: disable=arguments-differ + mask: torch.Tensor = None, + ) -> torch.Tensor: + """ + Compute a weighted average of the 'tensors'. The input tensors an be any shape + with at least two dimensions, but must all be the same shape. + When 'do_layer_norm=True', the 'mask' is required input. If the 'tensors' are + dimensioned '(dim_0, ..., dim_{n-1}, dim_n)', then the 'mask' is dimensioned + '(dim_0, ..., dim_{n-1})', as in the typical case with 'tensors' of shape + '(batch_size, timesteps, dim)' and 'mask' of shape '(batch_size, timesteps)'. + When 'do_layer_norm=False' the 'mask' is ignored. 
+ """ + if len(tensors) != self.mixture_size: + raise Exception( + "{} tensors were passed, but the module was initialized to \ + mix {} tensors.".format( + len(tensors), self.mixture_size + ) + ) + + def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked): + tensor_masked = tensor * broadcast_mask + mean = torch.sum(tensor_masked) / num_elements_not_masked + variance = ( + torch.sum(((tensor_masked - mean) * broadcast_mask) ** 2) + / num_elements_not_masked + ) + return (tensor - mean) / torch.sqrt(variance + 1e-12) + + weights = torch.cat([parameter for parameter in self.scalar_parameters]) + + if self.training and self.dropout: + weights = torch.where( + self.dropout_mask.uniform_() > self.dropout, weights, self.dropout_fill + ) + + normed_weights = torch.nn.functional.softmax(weights, dim=0) + normed_weights = torch.split(normed_weights, split_size_or_sections=1) + + if not self.do_layer_norm: + pieces = [] + for weight, tensor in zip(normed_weights, tensors): + pieces.append(weight * tensor) + return self.gamma * sum(pieces) + + else: + mask_float = mask.float() + broadcast_mask = mask_float.unsqueeze(-1) + input_dim = tensors[0].size(-1) + num_elements_not_masked = torch.sum(mask_float) * input_dim + + pieces = [] + for weight, tensor in zip(normed_weights, tensors): + pieces.append( + weight + * _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked) + ) + return self.gamma * sum(pieces) + + def __repr__(self): + layer_scalars = [param.data.item() for param in self.scalar_parameters] + gamma = self.gamma.data.item() + + representation = "(" + for i in range(len(layer_scalars)): + representation += "\n\tLayer {} scalar: {:.4}".format(i, layer_scalars[i]) + return representation + "\n)\n(\n\tGamma: {:.4}\n)".format(gamma) diff --git a/caption/models/taggers/__init__.py b/caption/models/taggers/__init__.py new file mode 100644 index 0000000..a65d0bf --- /dev/null +++ b/caption/models/taggers/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +from .transformer_tagger import TransformerTagger + +__all__ = ["TransformerTagger"] diff --git a/caption/models/taggers/tagger_base.py b/caption/models/taggers/tagger_base.py new file mode 100644 index 0000000..9e91af5 --- /dev/null +++ b/caption/models/taggers/tagger_base.py @@ -0,0 +1,283 @@ +# -*- coding: utf-8 -*- +r""" +Word Tagger Base +============== + Abstract base class used to build new sequence tagging models + inside CAPTION. +""" +import sys +import pdb + +import numpy as np +import torch +import torch.nn as nn + +from caption.datasets import sequence_tagging_dataset +from caption.models.caption_base_model import CaptionModelBase +from caption.models.metrics import classification_report +from test_tube import HyperOptArgumentParser +from torchnlp.encoders import LabelEncoder +from torchnlp.encoders.text import stack_and_pad_tensors +from torchnlp.utils import collate_tensors + + +class Tagger(CaptionModelBase): + """ + Tagger base class. + + :param hparams: HyperOptArgumentParser containing the hyperparameters. + """ + + def __init__(self, hparams: HyperOptArgumentParser,) -> None: + super().__init__(hparams) + + def _build_model(self): + self.label_encoder = LabelEncoder( + self.hparams.tag_set.split(","), reserved_labels=[] + ) + + def _build_loss(self): + """ Initializes the loss function/s. 
""" + weights = ( + np.array([float(x) for x in self.hparams.class_weights.split(",")]) + if self.hparams.class_weights != "ignore" + else np.array([]) + ) + + if self.hparams.loss == "cross_entropy": + self.loss = nn.CrossEntropyLoss( + reduction="sum", + ignore_index=self.label_encoder.vocab_size, + weight=torch.tensor(weights, dtype=torch.float32) + if weights.any() + else None, + ) + else: + raise Exception(f"{self.hparams.loss} is not a valid loss option.") + + def _retrieve_dataset(self, data_hparams, train=True, val=True, test=True): + """ Retrieves task specific dataset """ + return sequence_tagging_dataset(data_hparams, train, val, test) + + @property + def default_slot_index(self): + """ Index of the default slot to be ignored. (e.g. 'O' in 'B-I-O' tags) """ + return 0 + + def predict(self, sample: dict) -> list: + """ Function that runs a model prediction, + :param sample: a dictionary that must contain the the 'source' sequence. + + Return: list with predictions + """ + if self.training: + self.eval() + + return_dict = False + if isinstance(sample, dict): + sample = [sample] + return_dict = True + + with torch.no_grad(): + model_input, _ = self.prepare_sample(sample, prepare_target=False) + model_out = self.forward(**model_input) + tag_logits = model_out["tags"] + _, pred_labels = tag_logits.topk(1, dim=-1) + + for i in range(pred_labels.size(0)): + sample_tags = pred_labels[i, :, :].view(-1) + tags = [ + self.label_encoder.index_to_token[sample_tags[j]] + for j in range(model_input["word_lengths"][i]) + ] + sample[i]["predicted_tags"] = " ".join(tags) + sample[i]["tagged_sequence"] = " ".join( + [ + word + "/" + tag + for word, tag in zip(sample[i]["text"].split(), tags) + ] + ) + + sample[i][ + "encoded_ground_truth_tags" + ] = self.label_encoder.batch_encode( + [tag for tag in sample[i]["tags"].split()] + ) + + if self.hparams.ignore_last_tag: + if ( + sample[i]["encoded_ground_truth_tags"][ + model_input["word_lengths"][i] - 1 + ] + == 1 + ): + sample[i]["encoded_ground_truth_tags"][ + model_input["word_lengths"][i] - 1 + ] = self.label_encoder.vocab_size + + if return_dict: + return sample[0] + + return sample + + def _compute_loss(self, model_out: dict, targets: dict) -> torch.tensor: + """ + Computes Loss value according to a loss function. + :param model_out: model specific output with predicted tag logits + a tensor [batch_size x seq_length x num_tags] + :param targets: Target tags [batch_size x seq_length] + """ + logits = model_out["tags"].view(-1, model_out["tags"].size(-1)) + labels = targets["tags"].view(-1) + return self.loss(logits, labels) + + def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict): + """ + Function that prepares a sample to input the model. + :param sample: list of dictionaries. + + Returns: + - dictionary with the expected model inputs. + - dictionary with the expected target values. + """ + sample = collate_tensors(sample) + inputs = self.encoder.prepare_sample(sample["text"], trackpos=True) + if not prepare_target: + return inputs, {} + + tags, _ = stack_and_pad_tensors( + [self.label_encoder.batch_encode(tags.split()) for tags in sample["tags"]], + padding_index=self.label_encoder.vocab_size, + ) + + if self.hparams.ignore_first_title: + first_tokens = tags[:, 0].clone() + tags[:, 0] = first_tokens.masked_fill_( + first_tokens == self._label_encoder.token_to_index["T"], + self.label_encoder.vocab_size, + ) + + # TODO is this still needed ? 
+ if self.hparams.ignore_last_tag: + lengths = [len(tags.split()) for tags in sample["tags"]] + lengths = np.asarray(lengths) + k = 0 + for length in lengths: + if tags[k][length - 1] == 1: + tags[k][length - 1] = self.label_encoder.vocab_size + k += 1 + + targets = {"tags": tags} + return inputs, targets + + def _compute_metrics(self, outputs: list) -> dict: + """ + Private function that computes metrics of interest based on model predictions + and respective targets. + """ + predictions = [batch_out["val_prediction"]["tags"] for batch_out in outputs] + targets = [batch_out["val_target"]["tags"] for batch_out in outputs] + + predicted_tags, ground_truth = [], [] + for i in range(len(predictions)): + # Get logits and reshape predictions + batch_predictions = predictions[i] + logits = batch_predictions.view(-1, batch_predictions.size(-1)).cpu() + _, pred_labels = logits.topk(1, dim=-1) + + # Reshape targets + batch_targets = targets[i].view(-1).cpu() + + assert batch_targets.size() == pred_labels.view(-1).size() + ground_truth.append(batch_targets) + predicted_tags.append(pred_labels.view(-1)) + + return classification_report( + torch.cat(predicted_tags).numpy(), + torch.cat(ground_truth).numpy(), + padding=self.label_encoder.vocab_size, + labels=self.label_encoder.token_to_index, + ignore=self.default_slot_index, + ) + + @classmethod + def add_model_specific_args( + cls, parser: HyperOptArgumentParser + ) -> HyperOptArgumentParser: + """ Parser for Estimator specific arguments/hyperparameters. + :param parser: HyperOptArgumentParser obj + + Returns: + - updated parser + """ + parser = super(Tagger, Tagger).add_model_specific_args(parser) + parser.add_argument( + "--tag_set", + type=str, + default="L,U,T", + help="Task tags we want to use.\ + Note that the 'default' label should appear first", + ) + # Loss + parser.add_argument( + "--loss", + default="cross_entropy", + type=str, + help="Loss function to be used.", + choices=["cross_entropy"], + ) + parser.add_argument( + "--class_weights", + default="ignore", + type=str, + help='Weights for each of the classes we want to tag (e.g: "1.0,7.0,8.0").', + ) + ## Data args: + parser.add_argument( + "--data_type", + default="csv", + type=str, + help="The type of the file containing the training/dev/test data.", + choices=["csv"], + ) + parser.add_argument( + "--train_path", + default="data/dummy_train.csv", + type=str, + help="Path to the file containing the train data.", + ) + parser.add_argument( + "--dev_path", + default="data/dummy_test.csv", + type=str, + help="Path to the file containing the dev data.", + ) + parser.add_argument( + "--test_path", + default="data/dummy_test.csv", + type=str, + help="Path to the file containing the test data.", + ) + parser.add_argument( + "--loader_workers", + default=0, + type=int, + help=( + "How many subprocesses to use for data loading. 0 means that" + "the data will be loaded in the main process." 
+ ), + ) + # Metric args: + parser.add_argument( + "--ignore_first_title", + default=False, + help="When used, this flag ignores T tags in the first position.", + action="store_true", + ) + parser.add_argument( + "--ignore_last_tag", + default=False, + help="When used, this flag ignores S tags in the last position.", + action="store_true", + ) + return parser diff --git a/caption/models/taggers/transformer_tagger.py b/caption/models/taggers/transformer_tagger.py new file mode 100644 index 0000000..a5b7838 --- /dev/null +++ b/caption/models/taggers/transformer_tagger.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +r""" +Transformer Tagger +============== + Model that uses a pretrained transformer model to tag sequences of words. +""" +import torch +from torch import nn + +from caption.models.encoders import Encoder, str2encoder +from caption.models.scalar_mix import ScalarMixWithDropout +from caption.models.taggers.tagger_base import Tagger +from caption.optimizers import build_optimizer +from caption.schedulers import build_scheduler +from test_tube import HyperOptArgumentParser + + +class TransformerTagger(Tagger): + """ + Word tagger that uses a pretrained Transformer model to extract features from text. + + :param hparams: HyperOptArgumentParser containing the hyperparameters. + """ + + def __init__(self, hparams: HyperOptArgumentParser,) -> None: + super().__init__(hparams) + + def _build_model(self) -> Tagger: + """ + Initializes the estimator architecture. + """ + super()._build_model() + self.layer = ( + int(self.hparams.layer) + if self.hparams.layer != "mix" + else self.hparams.layer + ) + + if self.hparams.concat_tokens: + self.tagging_head = nn.Linear( + 2 * self.encoder.output_units, self.label_encoder.vocab_size + ) + else: + self.tagging_head = nn.Linear( + self.encoder.output_units, self.label_encoder.vocab_size + ) + + self.dropout = nn.Dropout(self.hparams.dropout) + self.scalar_mix = ( + ScalarMixWithDropout( + mixture_size=self.encoder.num_layers, + do_layer_norm=True, + dropout=self.hparams.scalar_mix_dropout, + ) + if self.layer == "mix" + else None + ) + + def _build_encoder(self, hparams: HyperOptArgumentParser) -> Encoder: + """ + Initializes the encoder. + """ + return str2encoder[self.hparams.encoder_model].from_pretrained(hparams) + + def configure_optimizers(self): + """ Sets different Learning rates for different parameter groups. """ + parameters = [ + {"params": self.tagging_head.parameters()}, + { + "params": self.encoder.parameters(), + "lr": self.hparams.encoder_learning_rate, + }, + ] + if self.scalar_mix: + parameters.append( + { + "params": self.scalar_mix.parameters(), + "lr": self.hparams.encoder_learning_rate, + } + ) + optimizer = build_optimizer(parameters, self.hparams) + scheduler = build_scheduler(optimizer, self.hparams) + return [optimizer], [scheduler] + + def forward( + self, + tokens: torch.tensor, + lengths: torch.tensor, + word_boundaries: torch.tensor, + word_lengths: torch.tensor, + ) -> dict: + """ + Function that encodes a sequence and returns the punkt and cap tags. + + :param tokens: wordpiece tokens [batch_size x wordpiece_length] + :param lengths: wordpiece sequences lengths [batch_size] + :param word_boundaries: wordpiece token positions [batch_size x word_length] + :param word_lengths: word sequences lengths [batch_size] + + Return: Dictionary with model outputs to be passed to the loss function. 
+ """ + # When using just one GPU this should not change behavior + # but when splitting batches across GPU the tokens have padding + # from the entire original batch + if self.trainer and self.trainer.use_dp and self.trainer.num_gpus > 1: + tokens = tokens[:, : lengths.max()] + + encoder_out = self.encoder(tokens, lengths) + + if self.scalar_mix: + embeddings = self.scalar_mix(encoder_out["all_layers"], encoder_out["mask"]) + else: + try: + embeddings = encoder_out["all_layers"][self.layer] + except IndexError: + raise Exception( + "Invalid model layer {}. Only {} layers available".format( + self.hparams.layer, self.encoder.num_layers + ) + ) + + word_embeddings = torch.cat( + [ + w.index_select(0, i).unsqueeze(0) + for w, i in zip(embeddings, word_boundaries) + ] + ) + if self.hparams.concat_tokens: + concat = [ + torch.cat( + [word_embeddings[i, j, :], word_embeddings[i, j + 1, :]], dim=0 + ) + if j < word_embeddings.shape[1] - 1 + else torch.cat( + [word_embeddings[i, j, :], word_embeddings[i, j, :]], dim=0 + ) + for i in range(word_embeddings.shape[0]) + for j in range(word_embeddings.shape[1]) + ] + new_embedds = torch.stack(concat).view( + word_embeddings.shape[0], + word_embeddings.shape[1], + 2 * self.encoder.output_units, + ) + tag_predictions = self.tagging_head(self.dropout(new_embedds)) + else: + tag_predictions = self.tagging_head(self.dropout(word_embeddings)) + + return { + "tags": tag_predictions, + } + + @staticmethod + def add_model_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + parser = super(TransformerTagger, TransformerTagger).add_model_specific_args( + parser + ) + # Parameters for the Encoder model + parser.add_argument( + "--encoder_model", + default="RoBERTa", + type=str, + help="Encoder model to be used.", + choices=["BERT", "RoBERTa", "XLM-RoBERTa"], + ) + parser.add_argument( + "--pretrained_model", + default="roberta.base", + type=str, + help=( + "Encoder pretrained model to be used. " + "(e.g. roberta.base or roberta.large" + ), + ) + parser.add_argument( + "--encoder_learning_rate", + default=1e-05, + type=float, + help="Encoder specific learning rate.", + ) + parser.opt_list( + "--dropout", + default=0.1, + type=float, + help="Dropout to be applied in feed forward net on top.", + tunable=True, + options=[0.1, 0.2, 0.3, 0.4, 0.5], + ) + parser.add_argument( + "--layer", + default="-1", + type=str, + help=( + "Encoder model layer to be used. Last one is the default. " + "If 'mix' all the encoder layer's will be combined with layer-wise attention" + ), + ) + parser.opt_list( + "--scalar_mix_dropout", + default=0.0, + type=float, + tunable=False, + options=[0.0, 0.05, 0.1, 0.15, 0.2], + help=( + "The ammount of layer wise dropout when using scalar_mix option for layer pooling. " + "Only applicable if the 'layer' parameters is set to mix" + ), + ) + parser.add_argument( + "--concat_tokens", + default=False, + help="Apply concatenation of consecutive words to feed to the linear projection", + action="store_true", + ) + return parser diff --git a/caption/models/utils.py b/caption/models/utils.py new file mode 100644 index 0000000..f784d98 --- /dev/null +++ b/caption/models/utils.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +import torch +from caption.tokenizers import TextEncoderBase + + +def mask_fill( + fill_value: float, + tokens: torch.tensor, + embeddings: torch.tensor, + padding_index: int, +) -> torch.tensor: + """ + Function that masks embeddings representing padded elements. 
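The 'index_select' gather in TransformerTagger.forward is the easiest part to misread, so here is a standalone sketch with made-up shapes showing how 'word_boundaries' reduces wordpiece embeddings to one vector per word:

import torch

# One sentence, 6 wordpieces, hidden size 4. The words start at wordpiece
# positions 1, 2 and 4 (position 0 holds the BOS/CLS token).
embeddings = torch.randn(1, 6, 4)
word_boundaries = torch.tensor([[1, 2, 4]])

word_embeddings = torch.cat(
    [w.index_select(0, i).unsqueeze(0) for w, i in zip(embeddings, word_boundaries)]
)
print(word_embeddings.shape)  # torch.Size([1, 3, 4]): one row per word for the tagging head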
+ :param fill_value: the value to fill the embeddings belonging to padded tokens. + :param tokens: The input sequences [bsz x seq_len]. + :param embeddings: word embeddings [bsz x seq_len x hiddens]. + :param padding_index: Index of the padding token. + """ + padding_mask = tokens.eq(padding_index).unsqueeze(-1) + return embeddings.float().masked_fill_(padding_mask, fill_value).type_as(embeddings) + + +def mask_tokens( + inputs: torch.tensor, + tokenizer: TextEncoderBase, + mlm_probability: float = 0.15, + ignore_index: int = -100, +): + """ Mask tokens function from Hugging Face that prepares masked tokens inputs/labels for + masked language modeling. + + :param inputs: Input tensor to be masked. + :param tokenizer: COMET text encoder. + :param mlm_probability: Probability of masking a token (default: 15%). + :param ignore_index: Specifies a target value that is ignored and does not contribute to + the input gradient (default: -100). + + Returns: + - Tuple with input to the model and the target. + """ + if tokenizer.mask_index is None: + raise ValueError( + "This tokenizer does not have a mask token which is necessary for masked language" + "modeling. Remove the --mlm flag if you want to use this tokenizer." + ) + + labels = inputs.clone() + probability_matrix = torch.full(labels.shape, mlm_probability) + special_tokens_mask = [ + tokenizer.get_special_tokens_mask(val) for val in labels.tolist() + ] + probability_matrix.masked_fill_( + torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0 + ) + padding_mask = labels.eq(tokenizer.padding_index) + probability_matrix.masked_fill_(padding_mask, value=0.0) + + masked_indices = torch.bernoulli(probability_matrix).bool() + labels[~masked_indices] = ignore_index # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with ([MASK]) + indices_replaced = ( + torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + ) + + inputs[indices_replaced] = tokenizer.mask_index + + # 10% of the time, we replace masked input tokens with random word + indices_random = ( + torch.bernoulli(torch.full(labels.shape, 0.5)).bool() + & masked_indices + & ~indices_replaced + ) + random_words = torch.randint(tokenizer.vocab_size, labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels diff --git a/caption/optimizers/__init__.py b/caption/optimizers/__init__.py new file mode 100644 index 0000000..e5641eb --- /dev/null +++ b/caption/optimizers/__init__.py @@ -0,0 +1,27 @@ +""" Module defining Optimizers. 
""" +from .adamw import AdamW +from .radam import RAdam +from .optim_args import OptimArgs +from .adam import Adam +from .adamax import Adamax + +str2optimizer = {"AdamW": AdamW, "RAdam": RAdam, "Adam": Adam, "Adamax": Adamax} + + +def build_optimizer(params, hparams): + """ + Function that builds an optimizer from the HyperOptArgumentParser + :param params: Model parameters + :param hparams: HyperOptArgumentParser + """ + return str2optimizer[hparams.optimizer].from_hparams(params, hparams) + + +def add_optimizer_args(parser, optimizer: str): + try: + return str2optimizer[optimizer].add_optim_specific_args(parser) + except KeyError: + raise Exception(f"{optimizer} is not a valid optimizer option!") + + +__all__ = ["AdamW", "RAdam", "Adam", "Adamax"] diff --git a/caption/optimizers/adam.py b/caption/optimizers/adam.py new file mode 100644 index 0000000..f1f997f --- /dev/null +++ b/caption/optimizers/adam.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +from test_tube import HyperOptArgumentParser + +from torch.optim import Adam as RegularAdam + +from .optim_args import OptimArgs + + +class Adam(RegularAdam, OptimArgs): + """ + Wrapper for the pytorch Adam optimizer. + https://pytorch.org/docs/stable/_modules/torch/optim/adam.html + + :param params: Model parameters + :param lr: learning rate. + :param betas: Adams beta parameters (b1, b2). + :param eps: Adams epsilon. + :param weight_decay: Weight decay. + :param correct_bias: Can be set to False to avoid correcting bias in Adam. + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: list = [0.9, 0.999], + eps: float = 1e-6, + weight_decay: float = 0.0, + amsgrad: bool = False, + ) -> None: + super(Adam, self).__init__( + params=params, + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ) + + @classmethod + def from_hparams(cls, params, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + return Adam( + params, + hparams.learning_rate, + (hparams.b1, hparams.b2), + hparams.eps, + hparams.weight_decay, + hparams.amsgrad, + ) + + @staticmethod + def add_optim_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parser: + """ + parser = super(Adam, Adam).add_optim_specific_args(parser) + parser.add_argument( + "--b1", default=0.9, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument( + "--b2", default=0.999, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument("--eps", default=1e-6, type=float, help="Adams epsilon.") + parser.add_argument( + "--weight_decay", default=0.0, type=float, help="Weight decay." + ) + parser.add_argument( + "--amsgrad", + default=False, + help="Whether to use the AMSGrad variant of this algorithm from the paper:\ + 'On the Convergence of Adam and Beyond'", + action="store_true", + ) + return parser diff --git a/caption/optimizers/adamax.py b/caption/optimizers/adamax.py new file mode 100644 index 0000000..af2aac3 --- /dev/null +++ b/caption/optimizers/adamax.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +from test_tube import HyperOptArgumentParser + +from torch.optim import Adamax as TorchAdamax + +from .optim_args import OptimArgs + + +class Adamax(TorchAdamax, OptimArgs): + """ + Wrapper for the pytorch Adamax optimizer. 
+ https://pytorch.org/docs/stable/_modules/torch/optim/adamax.html + + :param params: Model parameters + :param lr: learning rate. + :param betas: Adams beta parameters (b1, b2). + :param eps: Adams epsilon. + :param weight_decay: Weight decay. + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: list = [0.9, 0.999], + eps: float = 1e-6, + weight_decay: float = 0.0, + ) -> None: + super(Adamax, self).__init__( + params=params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay + ) + + @classmethod + def from_hparams(cls, params, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + return Adamax( + params, + hparams.learning_rate, + (hparams.b1, hparams.b2), + hparams.eps, + hparams.weight_decay, + ) + + @staticmethod + def add_optim_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parser: + """ + parser = super(Adamax, Adamax).add_optim_specific_args(parser) + parser.add_argument( + "--b1", default=0.9, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument( + "--b2", default=0.999, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument("--eps", default=1e-6, type=float, help="Adams epsilon.") + parser.add_argument( + "--weight_decay", default=0.0, type=float, help="Weight decay." + ) + return parser diff --git a/caption/optimizers/adamw.py b/caption/optimizers/adamw.py new file mode 100644 index 0000000..827c4ab --- /dev/null +++ b/caption/optimizers/adamw.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +from test_tube import HyperOptArgumentParser + +from transformers import AdamW as HuggingFaceAdamW + +from .optim_args import OptimArgs + + +class AdamW(HuggingFaceAdamW, OptimArgs): + """ + Wrapper for the huggingface AdamW optimizer. + https://huggingface.co/transformers/v2.1.1/main_classes/optimizer_schedules.html#adamw + + :param params: Model parameters + :param lr: learning rate. + :param betas: Adams beta parameters (b1, b2). + :param eps: Adams epsilon. + :param weight_decay: Weight decay. + :param correct_bias: Can be set to False to avoid correcting bias in Adam. + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: list = [0.9, 0.999], + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + ) -> None: + super(AdamW, self).__init__( + params=params, + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + ) + + @classmethod + def from_hparams(cls, params, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + return AdamW( + params, + hparams.learning_rate, + (hparams.b1, hparams.b2), + hparams.eps, + hparams.weight_decay, + hparams.correct_bias, + ) + + @staticmethod + def add_optim_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parser: + """ + parser = super(AdamW, AdamW).add_optim_specific_args(parser) + parser.add_argument( + "--b1", default=0.9, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument( + "--b2", default=0.999, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument("--eps", default=1e-6, type=float, help="Adams epsilon.") + parser.add_argument( + "--weight_decay", default=0.0, type=float, help="Weight decay." 
+ ) + parser.add_argument( + "--correct_bias", + default=False, + help="If this flag is on the correct_bias AdamW parameter is set to True.", + action="store_true", + ) + return parser diff --git a/caption/optimizers/optim_args.py b/caption/optimizers/optim_args.py new file mode 100644 index 0000000..d187642 --- /dev/null +++ b/caption/optimizers/optim_args.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from test_tube import HyperOptArgumentParser + + +class OptimArgs(object): + """ + Optimizer classes can Inheritance directly from the Pytorch Optimizer + class but we want to extend the normal Optimizer class behavior + with the add_optim_specific_args function. + + This class defines an Interface for adding Optimizer specific arguments + to the Namespace + """ + + @classmethod + def from_hparams(cls, params, hparams): + """ + Initializes the optimizer from the parameters in the HyperOptArgumentParser + """ + raise NotImplementedError + + @staticmethod + def add_optim_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parser: + """ + parser.opt_list( + "--learning_rate", + default=5e-5, + type=float, + tunable=True, + options=[1e-05, 3e-05, 5e-05, 8e-05, 1e-04], + help="Optimizer learning rate.", + ) + return parser diff --git a/caption/optimizers/radam.py b/caption/optimizers/radam.py new file mode 100644 index 0000000..69b5335 --- /dev/null +++ b/caption/optimizers/radam.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- +import math + +import torch +from torch.optim.optimizer import Optimizer + +from test_tube import HyperOptArgumentParser + +from .optim_args import OptimArgs + + +class RAdam(Optimizer, OptimArgs): + """ + RAdam optimizer: https://arxiv.org/abs/1908.03265 + Check the original implementation: + - https://github.com/LiyuanLucasLiu/RAdam + Medium Post: + - https://medium.com/@lessw/new-state-of-the-art-ai-optimizer-rectified-adam-radam-5d854730807b + + :param params: Model parameters + :param lr: Learning rate to be used. + :param betas: Adams beta parameters. + :param eps: Adams epsilon. + :param weight_decay: Weight decay. 
+ """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: tuple = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + degenerated_to_sgd: bool = True, + ) -> None: + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if ( + isinstance(params, (list, tuple)) + and len(params) > 0 + and isinstance(params[0], dict) + ): + for param in params: + if "betas" in param and ( + param["betas"][0] != betas[0] or param["betas"][1] != betas[1] + ): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)], + ) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + @classmethod + def from_hparams(cls, params, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + return RAdam( + params, + hparams.learning_rate, + (hparams.b1, hparams.b2), + hparams.eps, + hparams.weight_decay, + hparams.degenerated_to_sgd, + ) + + @staticmethod + def add_optim_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parent_parser: + """ + parser = super(RAdam, RAdam).add_optim_specific_args(parser) + parser.add_argument( + "--b1", default=0.9, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument( + "--b2", default=0.999, type=float, help="Adams beta parameters (b1, b2)." + ) + parser.add_argument("--eps", default=1e-6, type=float, help="Adams epsilon.") + parser.add_argument( + "--weight_decay", default=0.0, type=float, help="Weight decay." 
+ ) + parser.add_argument( + "--degenerated_to_sgd", + default=False, + help="If this flag is on the degenerated_to_sgd RAdam parameter is set to True.", + action="store_true", + ) + return parser + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_( + -group["weight_decay"] * group["lr"], p_data_fp32 + ) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_( + -group["weight_decay"] * group["lr"], p_data_fp32 + ) + p_data_fp32.add_(-step_size * group["lr"], exp_avg) + p.data.copy_(p_data_fp32) + + return loss diff --git a/caption/schedulers/__init__.py b/caption/schedulers/__init__.py new file mode 100644 index 0000000..9606aef --- /dev/null +++ b/caption/schedulers/__init__.py @@ -0,0 +1,29 @@ +""" Module defining Schedulers. 
""" +from .linear_warmup import LinearWarmup +from .constant_lr import ConstantLR +from .warmup_constant import WarmupConstant + + +str2scheduler = { + "linear_warmup": LinearWarmup, + "constant": ConstantLR, + "warmup_constant": WarmupConstant, +} + + +def build_scheduler(optimizer, hparams): + """ + Function that builds a scheduler from the HyperOptArgumentParser + :param hparams: HyperOptArgumentParser + """ + return str2scheduler[hparams.scheduler].from_hparams(optimizer, hparams) + + +def add_scheduler_args(parser, scheduler: str): + try: + return str2scheduler[scheduler].add_scheduler_specific_args(parser) + except KeyError: + raise Exception(f"{scheduler} is not a valid scheduler option!") + + +__all__ = ["LinearWarmup", "ConstantLR", "WarmupConstant"] diff --git a/caption/schedulers/constant_lr.py b/caption/schedulers/constant_lr.py new file mode 100644 index 0000000..3f76f5f --- /dev/null +++ b/caption/schedulers/constant_lr.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from test_tube import HyperOptArgumentParser + +from .scheduler_args import SchedulerArgs + + +class ConstantLR(LambdaLR, SchedulerArgs): + """ + Constant learning rate schedule + + Wrapper for the huggingface Constant LR Scheduler. + https://huggingface.co/transformers/v2.1.1/main_classes/optimizer_schedules.html + + :param optimizer: torch.optim.Optimizer + :param last_epoch: + """ + + def __init__(self, optimizer: Optimizer, last_epoch: int = -1) -> None: + super(ConstantLR, self).__init__(optimizer, lambda _: 1, last_epoch) + + @classmethod + def from_hparams(cls, optimizer, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + return ConstantLR(optimizer, hparams.last_epoch) + + @staticmethod + def add_scheduler_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parser: + """ + return super(ConstantLR, ConstantLR).add_scheduler_specific_args(parser) diff --git a/caption/schedulers/linear_warmup.py b/caption/schedulers/linear_warmup.py new file mode 100644 index 0000000..b970fea --- /dev/null +++ b/caption/schedulers/linear_warmup.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +import sys + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from test_tube import HyperOptArgumentParser + +from .scheduler_args import SchedulerArgs + + +class LinearWarmup(LambdaLR, SchedulerArgs): + """ + Create a schedule with a learning rate that decreases linearly after + linearly increasing during a warmup period. + + :param optimizer: torch.optim.Optimizer + :param warmup_steps: Linearly increases learning rate from 0 to 1*learning_rate over warmup_steps. + :param num_training_steps: Linearly decreases learning rate from 1*learning_rate to 0. over remaining + t_total - warmup_steps steps. 
+ :param last_epoch: + """ + + def __init__( + self, + optimizer: Optimizer, + warmup_steps: int, + num_training_steps: int, + last_epoch: int = -1, + ) -> None: + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + return max( + 0.0, + float(num_training_steps - current_step) + / float(max(1, num_training_steps - warmup_steps)), + ) + + super(LinearWarmup, self).__init__(optimizer, lr_lambda, last_epoch) + + @classmethod + def from_hparams(cls, optimizer, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + return LinearWarmup( + optimizer, + hparams.warmup_steps, + hparams.num_training_steps, + hparams.last_epoch, + ) + + @staticmethod + def add_scheduler_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parent_parser: + """ + parser = super(LinearWarmup, LinearWarmup).add_scheduler_specific_args(parser) + parser.add_argument( + "--warmup_steps", + type=int, + default=1, + help="Linearly increases learning rate from 0 to 1 over warmup_steps.", + ) + parser.add_argument( + "--num_training_steps", + type=int, + default=sys.maxsize, + help="Linearly decreases learning rate from 1*learning_rate to 0*learning_rate over \ + remaining t_total - warmup_steps steps.", + ) + return parser diff --git a/caption/schedulers/scheduler_args.py b/caption/schedulers/scheduler_args.py new file mode 100644 index 0000000..441106a --- /dev/null +++ b/caption/schedulers/scheduler_args.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +from test_tube import HyperOptArgumentParser + + +class SchedulerArgs(object): + """ + The Schedulers can Inheritance directly from the Pytorch lr schedulers + but we want to extend the normal lr scheduler class behavior + with the add_scheduler_specific_args function. + + This class defines an Interface for adding Scheduler specific arguments + to the Namespace and a method to build from the HyperOptArgumentParser + """ + + @classmethod + def from_hparams(cls, optimizer, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + raise NotImplementedError + + @staticmethod + def add_scheduler_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses scheduler specific arguments and adds + them to the Namespace + :param parser: + """ + parser.add_argument( + "--last_epoch", default=-1, type=int, help="Scheduler last epoch step" + ) + return parser diff --git a/caption/schedulers/warmup_constant.py b/caption/schedulers/warmup_constant.py new file mode 100644 index 0000000..b7ee25e --- /dev/null +++ b/caption/schedulers/warmup_constant.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +from test_tube import HyperOptArgumentParser + +from .scheduler_args import SchedulerArgs + + +class WarmupConstant(LambdaLR, SchedulerArgs): + """ + Warmup Linear scheduler. + 1) Linearly increases learning rate from 0 to 1 over warmup_steps + training steps. + 2) Keeps the learning rate constant afterwards. + + :param optimizer: torch.optim.Optimizer + :param warmup_steps: Linearly increases learning rate from 0 to 1 over warmup_steps. 
+ :param last_epoch: + """ + + def __init__( + self, optimizer: Optimizer, warmup_steps: int, last_epoch: int = -1 + ) -> None: + def lr_lambda(current_step): + if current_step < warmup_steps: + return float(current_step) / float(max(1.0, warmup_steps)) + return 1.0 + + super(WarmupConstant, self).__init__(optimizer, lr_lambda, last_epoch) + + @classmethod + def from_hparams(cls, optimizer, hparams): + """ + Initializes the scheduler from the parameters in the HyperOptArgumentParser + """ + return WarmupConstant(optimizer, hparams.warmup_steps, hparams.last_epoch) + + @staticmethod + def add_scheduler_specific_args( + parser: HyperOptArgumentParser, + ) -> HyperOptArgumentParser: + """ + Functions that parses Optimizer specific arguments and adds + them to the Namespace + :param parent_parser: + """ + parser = super(WarmupConstant, WarmupConstant).add_scheduler_specific_args( + parser + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=1, + help="Linearly increases learning rate from 0*learning_rate to 1*learning_rate over warmup_steps.", + ) + return parser diff --git a/caption/testing.py b/caption/testing.py new file mode 100644 index 0000000..680992e --- /dev/null +++ b/caption/testing.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +import json +import logging +import pdb + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import SequentialSampler, BatchSampler +from tqdm import tqdm + +from caption.models.metrics import classification_report +from caption.models import str2model +from pytorch_lightning import Trainer +from test_tube import HyperOptArgumentParser +from torchnlp.encoders.text import stack_and_pad_tensors +from torchnlp.utils import collate_tensors, sampler_to_iterator + +log = logging.getLogger("Shell") + + +def setup_testing(hparams: HyperOptArgumentParser): + """ + Setup for the testing loop. + :param hparams: HyperOptArgumentParser + + Returns: + - CAPTION model + - Test Set to be used. + """ + tags_csv_file = "/".join(hparams.checkpoint.split("/")[:-1] + ["meta_tags.csv"]) + tags = pd.read_csv(tags_csv_file, header=None, index_col=0, squeeze=True).to_dict() + model = str2model[tags["model"]].load_from_metrics( + weights_path=hparams.checkpoint, tags_csv=tags_csv_file + ) + log.info(model.hparams) + + # Make sure model is in prediction mode + model.eval() + model.freeze() + log.info(f"{hparams.checkpoint} reloaded for testing.") + log.info(f"Testing with {hparams.test_path} testset.") + model._test_dataset = model._retrieve_dataset(hparams, train=False, val=False)[0] + return model, model._test_dataset + + +def run_testing(hparams: HyperOptArgumentParser): + model, testset = setup_testing(hparams) + log.info("Testing model in CPU. 
This might take a while.") + + sampler = SequentialSampler(testset) + iterator = sampler_to_iterator(testset, sampler) + predictions = [ + model.predict(sample) + for i, sample in tqdm(enumerate(iterator), total=len(testset)) + ] + + predicted_tags = model.label_encoder.batch_encode( + [tag for pred in predictions for tag in pred["predicted_tags"].split()] + ) + + ground_truth_tags = torch.stack( + [tag for pred in predictions for tag in pred["encoded_ground_truth_tags"]] + ) + + metrics = classification_report( + np.array(predicted_tags), + np.array(ground_truth_tags), + padding=model.label_encoder.vocab_size, + labels=model.label_encoder.token_to_index, + ignore=model.default_slot_index, + ) + log.info("-- Test metrics:\n{}".format(json.dumps(metrics, indent=1))) + + testing_output_file = "/".join( + hparams.checkpoint.split("/")[:-1] + + [hparams.checkpoint.split(".")[0].split("/")[-1] + "_predictions.json"] + ) + + predictions = [ + {k: v for k, v in d.items() if k != "encoded_ground_truth_tags"} + for d in predictions + ] + + with open(testing_output_file, "w") as outfile: + json.dump({"results": metrics, "predictions": predictions}, outfile) diff --git a/caption/tokenizers/__init__.py b/caption/tokenizers/__init__.py new file mode 100644 index 0000000..2f72057 --- /dev/null +++ b/caption/tokenizers/__init__.py @@ -0,0 +1,11 @@ +from .tokenizer_base import TextEncoderBase +from .bert_tokenizer import BERTTextEncoder +from .roberta_tokenizer import RoBERTaTextEncoder +from .hf_roberta_tokenizer import HfRoBERTaTextEncoder + +__all__ = [ + "BERTTextEncoder", + "TextEncoderBase", + "RoBERTaTextEncoder", + "HfRoBERTaTextEncoder", +] diff --git a/caption/tokenizers/bert_tokenizer.py b/caption/tokenizers/bert_tokenizer.py new file mode 100644 index 0000000..fb7ff57 --- /dev/null +++ b/caption/tokenizers/bert_tokenizer.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +r""" +Hugging Face BERT Tokenizer class +============== + Hugging Face BERT tokenizer wrapper. +""" +import torch + +from .tokenizer_base import TextEncoderBase +from torchnlp.encoders.text.text_encoder import TextEncoder +from transformers import BertTokenizer + + +class BERTTextEncoder(TextEncoderBase): + """ + BERT tokenizer. + + :param model: BERT model to be used. + """ + + def __init__(self, model: str) -> None: + super().__init__() + self.tokenizer = BertTokenizer.from_pretrained(model) + # Properties from the base class + self.stoi = self.tokenizer.vocab + self.itos = self.tokenizer.ids_to_tokens + self._bos_index = self.tokenizer.cls_token_id + self._pad_index = self.tokenizer.pad_token_id + self._eos_index = self.tokenizer.sep_token_id + self._unk_index = self.tokenizer.unk_token_id + self._mask_index = self.tokenizer.mask_token_id + + def encode_trackpos(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence' and keeps the alignments with the respective tags. + :param sequence: String 'sequence' to encode. + + Returns: + - torch.Tensor: Encoding of the 'sequence'. + - torch.Tensor: Alignment indexes + """ + sequence = TextEncoder.encode(self, sequence) + tag_index, vector = [], [self._bos_index,] + for index, token in enumerate(sequence.split()): + tag_index.append(len(vector)) + vector = vector + self.tokenizer.encode(token, add_special_tokens=False) + vector.append(self._eos_index) + return torch.tensor(vector), torch.tensor(tag_index) + + def encode(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence'. + :param sequence: String 'sequence' to encode. 
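Because the alignment-tracking encode is what the taggers rely on, here is a small usage sketch for the BERT wrapper above (hypothetical input sentence; running it downloads the bert-base-uncased vocabulary):

from caption.tokenizers import BERTTextEncoder

tokenizer = BERTTextEncoder("bert-base-uncased")
tokens, word_starts = tokenizer.encode_trackpos("an unaffable example")
# 'tokens'      holds the wordpiece ids wrapped in [CLS] ... [SEP];
# 'word_starts' holds, for each whitespace-separated word, the index of its first
#               wordpiece, which the taggers use to pick one embedding per word.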
+ + Returns: + - torch.Tensor: Encoding of the 'sequence'. + """ + sequence = TextEncoder.encode(self, sequence) + vector = self.tokenizer.encode(sequence) + return torch.tensor(vector) diff --git a/caption/tokenizers/hf_roberta_tokenizer.py b/caption/tokenizers/hf_roberta_tokenizer.py new file mode 100644 index 0000000..be27d48 --- /dev/null +++ b/caption/tokenizers/hf_roberta_tokenizer.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +r""" +Hugging Face RoBERTa Tokenizer class +============== + Hugging Face RoBERTa tokenizer wrapper. +""" +import torch + +from .tokenizer_base import TextEncoderBase +from torchnlp.encoders.text.text_encoder import TextEncoder +from transformers import RobertaTokenizer + + +class HfRoBERTaTextEncoder(TextEncoderBase): + """ + Hugging Face RoBERTa tokenizer. + + :param model: RoBERTa model to be used. + """ + + def __init__(self, model: str) -> None: + super().__init__() + self.tokenizer = RobertaTokenizer.from_pretrained(model) + # Properties from the base class + self.stoi = {} # ignored + self.itos = {} # ignored + self._bos_index = self.tokenizer.bos_token_id + self._pad_index = self.tokenizer.pad_token_id + self._eos_index = self.tokenizer.sep_token_id + self._unk_index = self.tokenizer.unk_token_id + # Hugging face mask_token_id is wrong! this is the real value: + self._mask_index = 250001 + # TODO update transformers version to avoid the above bug. + # https://github.com/huggingface/transformers/pull/2509 + + @property + def vocab_size(self) -> int: + """ + Returns: + int: Number of tokens in the dictionary. + """ + return self.tokenizer.vocab_size + + def encode_trackpos(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence' and keeps the alignments with the respective tags. + :param sequence: String 'sequence' to encode. + + Returns: + - torch.Tensor: Encoding of the 'sequence'. + - torch.Tensor: Alignment indexes + """ + sequence = TextEncoder.encode(self, sequence) + tag_index, vector = [], [self._bos_index,] + for index, token in enumerate(sequence.split()): + tag_index.append(len(vector)) + vector = vector + self.tokenizer.encode(token, add_special_tokens=False) + vector.append(self._eos_index) + return torch.tensor(vector), torch.tensor(tag_index) + + def encode(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence'. + :param sequence: String 'sequence' to encode. + + Returns: + - torch.Tensor: Encoding of the 'sequence'. + """ + sequence = TextEncoder.encode(self, sequence) + vector = self.tokenizer.encode(sequence) + return torch.tensor(vector) diff --git a/caption/tokenizers/roberta_tokenizer.py b/caption/tokenizers/roberta_tokenizer.py new file mode 100644 index 0000000..2457e1f --- /dev/null +++ b/caption/tokenizers/roberta_tokenizer.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +r""" +RoBERTa Tokenizer class +============== + Fairseq RoBERTa tokenizer wrapper. +""" +import torch + +from .tokenizer_base import TextEncoderBase +from torchnlp.encoders.text.text_encoder import TextEncoder + + +class RoBERTaTextEncoder(TextEncoderBase): + """ + RoBERTa encoder from Fairseq. + + :param tokenizer_func: RoBERTa tokenization function. + This can be easily obtain from the fairseq model (e.g: roberta.encode callable) + :param vocabulary: the dictionary containing the RoBERTa vocabulary. 
+ This can be easily obtain from the fairseq model + (e.g: roberta.task.source_dictionary.__dict__['indices']) + """ + + def __init__(self, encode_func: callable, vocabulary: dict) -> None: + super().__init__() + + self.encode_func = encode_func + # Properties from the base class + self.stoi = vocabulary + self.itos = {v: k for k, v in vocabulary.items()} + self._pad_index = self.stoi[""] + self._eos_index = self.stoi[""] + self._unk_index = self.stoi[""] + self._bos_index = self.stoi[""] + self._mask_index = self.stoi[""] + + def encode_trackpos(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence' and keeps the alignments with the respective tags. + :param sequence: String 'sequence' to encode. + + Returns: + - torch.Tensor: Encoding of the 'sequence'. + - torch.Tensor: Alignment indexes + """ + sequence = TextEncoder.encode(self, sequence) + tag_index, vector = [], [self._bos_index,] + tokens = sequence.split() + # Add whitespace to each token to prevent Ġ + tokens = [tokens[0]] + [" " + token for token in tokens[1:]] + for index, token in enumerate(tokens): + tag_index.append(len(vector)) + vector = vector + self.encode_func(token)[1:-1].tolist() + vector.append(self._eos_index) + return torch.tensor(vector), torch.tensor(tag_index) + + def encode(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence'. + :param sequence: String 'sequence' to encode. + + Returns: + - torch.Tensor: Encoding of the 'sequence'. + """ + sequence = TextEncoder.encode(self, sequence) + return self.encode_func(sequence) diff --git a/caption/tokenizers/tokenizer_base.py b/caption/tokenizers/tokenizer_base.py new file mode 100644 index 0000000..e289836 --- /dev/null +++ b/caption/tokenizers/tokenizer_base.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +r""" +Tokenizer base class +============== + Base class for all tokenizers. +""" +import torch + +from torchnlp.encoders import Encoder +from torchnlp.encoders.text import stack_and_pad_tensors +from torchnlp.encoders.text.text_encoder import TextEncoder + + +class TextEncoderBase(TextEncoder): + """ + Base class for the specific tokenizers of each model. + """ + + def __init__(self) -> None: + self.enforce_reversible = False + + @property + def unk_index(self) -> int: + """ Returns the index used for the unknown token. """ + return self._unk_index + + @property + def bos_index(self) -> int: + """ Returns the index used for the begin-of-sentence token. """ + return self._bos_index + + @property + def eos_index(self) -> int: + """ Returns the index used for the end-of-sentence token. """ + return self._eos_index + + @property + def mask_index(self) -> int: + """ Returns the index used for the end-of-sentence token. """ + return self._mask_index + + @property + def padding_index(self) -> int: + """ Returns the index used for padding. """ + return self._pad_index + + @property + def vocab(self) -> list: + """ + Returns: + list: List of tokens in the dictionary. + """ + return self.stoi + + @property + def vocab_size(self) -> int: + """ + Returns: + int: Number of tokens in the dictionary. + """ + return len(self.itos) + + def tokenize(self, sequence: str) -> list: + """ + Function that tokenizes a string. + - To be extended by subclasses. + """ + raise NotImplementedError + + def encode(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence'. + :param sequence: String 'sequence' to encode. + + Returns: + - torch.Tensor: Encoding of the 'sequence'. 
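+
+        Note: subclasses that wrap a pretrained tokenizer (e.g. BERTTextEncoder)
+        override this method and delegate the encoding to that tokenizer instead.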
+ """ + sequence = super().encode(sequence) + tokens = self.tokenize(sequence) + vector = [self.stoi.get(token, self.unk_index) for token in tokens] + return torch.tensor(vector) + + def batch_encode(self, iterator, dim=0, **kwargs) -> (torch.Tensor, torch.Tensor): + """ + :param iterator (iterator): Batch of text to encode. + :param dim (int, optional): Dimension along which to concatenate tensors. + :param **kwargs: Keyword arguments passed to 'encode'. + + Returns + torch.Tensor, torch.Tensor: Encoded and padded batch of sequences; Original lengths of + sequences. + """ + return stack_and_pad_tensors( + Encoder.batch_encode(self, iterator, **kwargs), + padding_index=self.padding_index, + dim=dim, + ) + + def encode_trackpos(self, sequence: str) -> torch.Tensor: + """ Encodes a 'sequence' and keeps track of the position of the beginning of each token before wordpieces. + :param sequence: String 'sequence' to encode. + + Returns: + - torch.Tensor: Encoding of the 'sequence'. + - torch.Tensor: Alignment indexes + """ + raise NotImplementedError + + def batch_encode_trackpos( + self, iterator, dim=0, **kwargs + ) -> (torch.Tensor, torch.Tensor): + """ + :param iterator (iterator): Batch of text to encode. + :param dim (int, optional): Dimension along which to concatenate tensors. + :param **kwargs: Keyword arguments passed to 'encode'. + + Returns + torch.Tensor, torch.Tensor: Encoded and padded batch of sequences; Original lengths of + sequences. + """ + sequences, tags = zip(*[self.encode_trackpos(object_) for object_ in iterator]) + sequences, seq_lengths = stack_and_pad_tensors( + sequences, padding_index=self.padding_index, dim=dim + ) + tag_idxs, tag_lengths = stack_and_pad_tensors(tags, padding_index=0, dim=dim) + return sequences, seq_lengths, tag_idxs, tag_lengths + + def get_special_tokens_mask(self, tokens): + """ Function from Hugging face to train language models. """ + return list( + map(lambda x: 1 if x in [self.bos_index, self.eos_index] else 0, tokens) + ) diff --git a/caption/training.py b/caption/training.py new file mode 100644 index 0000000..b43d5ab --- /dev/null +++ b/caption/training.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +import logging +import os + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from test_tube import HyperOptArgumentParser + +from caption.utils import setup_testube_logger + + +log = logging.getLogger("Shell") + + +def setup_training(hparams: HyperOptArgumentParser) -> tuple: + """ + Setup for the training loop. 
+ :param hparams: HyperOptArgumentParser + + Returns: + - pytorch_lightning Trainer + """ + if hparams.verbose: + log.info(hparams) + + if hparams.early_stopping: + # Enable Early stopping + early_stop_callback = EarlyStopping( + monitor=hparams.monitor, + min_delta=hparams.min_delta, + patience=hparams.patience, + verbose=hparams.verbose, + mode=hparams.metric_mode, + ) + else: + early_stop_callback = None + + # configure trainer + if hparams.epochs > 0.0: + hparams.min_epochs = hparams.epochs + hparams.max_epochs = hparams.epochs + + trainer = Trainer( + logger=setup_testube_logger(), + checkpoint_callback=True, + early_stop_callback=early_stop_callback, + default_save_path="experiments/", + gradient_clip_val=hparams.gradient_clip_val, + gpus=hparams.gpus, + show_progress_bar=False, + overfit_pct=hparams.overfit_pct, + check_val_every_n_epoch=hparams.check_val_every_n_epoch, + fast_dev_run=False, + accumulate_grad_batches=hparams.accumulate_grad_batches, + max_epochs=hparams.max_epochs, + min_epochs=hparams.min_epochs, + train_percent_check=hparams.train_percent_check, + val_percent_check=hparams.val_percent_check, + val_check_interval=hparams.val_check_interval, + log_save_interval=hparams.log_save_interval, + row_log_interval=hparams.row_log_interval, + distributed_backend=hparams.distributed_backend, + precision=hparams.precision, + weights_summary=hparams.weights_summary, + resume_from_checkpoint=hparams.resume_from_checkpoint, + profiler=hparams.profiler, + log_gpu_memory="all", + ) + + ckpt_path = os.path.join( + trainer.default_save_path, + trainer.logger.name, + f"version_{trainer.logger.version}", + "checkpoints", + ) + + # initialize Model Checkpoint Saver + checkpoint_callback = ModelCheckpoint( + filepath=ckpt_path, + save_top_k=hparams.save_top_k, + verbose=hparams.verbose, + monitor=hparams.monitor, + save_weights_only=hparams.save_weights_only, + period=hparams.period, + mode=hparams.metric_mode, + ) + trainer.checkpoint_callback = checkpoint_callback + return trainer + + +def add_trainer_specific_args(parser: HyperOptArgumentParser) -> HyperOptArgumentParser: + parser.add_argument("--seed", default=3, type=int, help="Training seed.") + parser.add_argument( + "--batch_size", default=32, type=int, help="Batch size to be used." + ) + parser.add_argument( + "--resume_from_checkpoint", + default=None, + type=str, + help=( + "To resume training from a specific checkpoint pass in the path here." + "(e.g. 'some/path/to/my_checkpoint.ckpt')" + ), + ) + parser.add_argument( + "--load_weights", + default=None, + type=str, + help=( + "Loads the model weights from a given checkpoint. " + "This flag differs from resume_from_checkpoint beacuse it only loads the" + "weights that match between the checkpoint model and the model we want to train. " + "It does not resume the entire training session (model/optimizer/scheduler etc..)." + ), + ) + parser.add_argument( + "--save_top_k", + default=1, + type=int, + help="The best k models according to the quantity monitored will be saved.", + ) + parser.add_argument( + "--monitor", default="pearson", type=str, help="Quantity to monitor." + ) + parser.add_argument( + "--metric_mode", + default="max", + type=str, + help=( + "One of {min, max}. If `--save_best_only`, the decision to " + "overwrite the current checkpoint is based on either the maximization " + "or the minimization of the monitored quantity." 
+ ), + choices=["min", "max"], + ) + parser.add_argument( + "--period", + default=1, + type=int, + help="Interval (number of epochs) between checkpoints.", + ) + parser.add_argument( + "--save_weights_only", + default=False, + help="If True, then only the model's weights will be saved.", + action="store_true", + ) + parser.add_argument( + "--verbose", + default=False, + help="Verbosity mode, True or False", + action="store_true", + ) + # Early Stopping + parser.add_argument( + "--early_stopping", + default=False, + help="If set to True Early Stopping is enabled.", + action="store_true", + ) + parser.add_argument( + "--patience", + default=3, + type=int, + help=( + "Number of epochs with no improvement " + "after which training will be stopped." + ), + ) + parser.add_argument( + "--min_delta", + default=0.0, + type=float, + help="Minimum change in the monitored quantity.", + ) + # pytorch-lightning trainer class specific args + parser.add_argument( + "--epochs", + default=-1, + type=int, + help=( + "Number of epochs to run. By default the number of epochs " + "is controlled but the min_nb_epochs and max_nb_epochs parameters." + ), + ) + parser.add_argument( + "--min_epochs", + default=1, + type=int, + help="Limits training to a minimum number of epochs", + ) + parser.add_argument( + "--max_epochs", + default=3, + type=int, + help="Limits training to a max number number of epochs", + ) + parser.add_argument( + "--accumulate_grad_batches", + default=1, + type=int, + help=( + "Accumulated gradients runs K small batches of size N before " + "doing a backwards pass. The effect is a large effective batch " + "size of size KxN." + ), + ) + parser.add_argument( + "--gradient_clip_val", + default=1.0, + type=float, + help="Max norm of the gradients.", + ) + parser.add_argument( + "--gpus", + default=1, + type=int, + help="ie: 2 gpus OR -1 to use all available gpus", + ) + parser.add_argument( + "--distributed_backend", + default="dp", + type=str, + help="Options: 'dp' (lightning ddp and ddp2 not working!)", + choices=["dp"], + ) + parser.add_argument( + "--precision", + default=32, + type=int, + help="Full precision (32), half precision (16).", + choices=[16, 32], + ) + parser.add_argument( + "--log_save_interval", + default=100, + type=int, + help="Writes logs to disk this often", + ) + parser.add_argument( + "--row_log_interval", + default=10, + type=int, + help="How often to add logging rows (does not write to disk)", + ) + parser.add_argument( + "--check_val_every_n_epoch", + default=1, + type=int, + help="Check val every n train epochs.", + ) + parser.add_argument( + "--train_percent_check", + default=1.0, + type=float, + help=( + "If you don't want to use the entire training set, " + "set how much of the train set you want to use with this flag." + ), + ) + parser.add_argument( + "--val_percent_check", + default=1.0, + type=float, + help=( + "If you don't want to use the entire dev set, set how much of the dev " + "set you want to use with this flag." + ), + ) + parser.add_argument( + "--val_check_interval", + default=1.0, + type=float, + help=( + "For large datasets it's often desirable to check validation multiple " + "times within a training loop. Pass in a float to check that often " + "within 1 training epoch." + ), + ) + parser.add_argument( + "--train_val_percent_check", + default=0.01, + type=float, + help=( + "In the end of each epoch a subset of the training data will be selected " + "to measure performance against training. 
Pass a float to set how much of " + "the training data you want to use" + ), + ) + # Debugging + parser.add_argument( + "--overfit_pct", + default=0.0, + type=float, + help=( + "A useful debugging trick is to make your model overfit a tiny fraction " + "of the data. Default: don't overfit (ie: normal training)" + ), + ) + parser.add_argument( + "--weights_summary", + default="full", + type=str, + help="Prints a summary of the weights when training begins.", + choices=["full", "top"], + ) + parser.add_argument( + "--profiler", + default=False, + help="If you only wish to profile the standard actions during training.", + action="store_true", + ) + parser.add_argument( + "--log_gpu_memory", + default=None, + type=str, + help="Logs (to a logger) the GPU usage for each GPU on the master machine.", + choices=["min_max", "full"], + ) + return parser diff --git a/caption/utils.py b/caption/utils.py new file mode 100644 index 0000000..21f39ea --- /dev/null +++ b/caption/utils.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +import os +import random +import shutil +from datetime import datetime + +import numpy as np +import torch +import yaml + +from pytorch_lightning.logging import TestTubeLogger +from test_tube import HyperOptArgumentParser +from test_tube.argparse_hopt import TTNamespace + + +def load_yaml_args(parser: HyperOptArgumentParser, log): + """ Function that load the args defined in a YAML file and replaces the values + parsed by the HyperOptArgumentParser """ + old_args = vars(parser.parse_args()) + configs = old_args.get("config") + if configs: + yaml_file = yaml.load(open(configs).read(), Loader=yaml.FullLoader) + for key, value in yaml_file.items(): + if key in old_args: + old_args[key] = value + else: + raise Exception( + "{} argument defined in {} is not valid!".format(key, configs) + ) + else: + log.warning( + "We recommend the usage of YAML files to keep track \ + of the hyperparameter during testing and training." + ) + return TTNamespace(**old_args) + + +def get_main_args_from_yaml(args): + """ Function for loading the __main__ arguments directly from the YAML """ + if not args.config: + raise Exception("You must pass a YAML file if not using the command line.") + try: + yaml_file = yaml.load(open(args.config).read(), Loader=yaml.FullLoader) + return yaml_file["optimizer"], yaml_file["scheduler"], yaml_file["model"] + except KeyError as e: + raise Exception("YAML file is missing the {} parameter.".format(e.args[0])) + + +def setup_testube_logger(): + """ Function that sets the TestTubeLogger to be used. 
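+    The logger version is the SLURM job id when available, otherwise a timestamp.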
""" + try: + job_id = os.environ["SLURM_JOB_ID"] + except Exception: + job_id = None + + now = datetime.now() + dt_string = now.strftime("%d-%m-%Y--%H-%M-%S") + return TestTubeLogger( + save_dir="experiments/", + version=job_id if job_id else dt_string, + name="lightning_logs", + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6c817d4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,13 @@ +scipy==1.4.* +numpy==1.17.* +pandas==1.0.* +PyYAML==5.3.* +tensorboard==2.0.2 +torch==1.3.1 +tqdm==4.43.* +transformers==2.3.0 +fairseq==0.9.* +pytorch-lightning==0.7.1 +test-tube==0.7.5 +pytorch-nlp==0.5.0 +scikit-learn==0.22.2.post1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4688670 --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +from setuptools import find_packages, setup + +setup( + name="caption", + version="1.0.0", + author="Ricardo Rei, Nuno Miguel Guerreiro", + download_url="https://gitlab.com/Unbabel/discovery-team/caption", + author_email="ricardo.rei@unbabel.com, nuno.guerreiro@unbabel.com", + packages=find_packages(exclude=["tests"]), + description="Provides automatic transcription enrichment for ASR data", + keywords=["Deep Learning", "PyTorch", "AI", "NLP", "Natural Language Processing"], + python_requires=">=3.6", + setup_requires=[], + install_requires=[ + line.strip() for line in open("requirements.txt", "r").readlines() + ], +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/lightning_models/__init__.py b/tests/lightning_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/lightning_models/test_taggers/__init__.py b/tests/lightning_models/test_taggers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/lightning_models/test_taggers/test_transformer_tagger.py b/tests/lightning_models/test_taggers/test_transformer_tagger.py new file mode 100644 index 0000000..38d4aab --- /dev/null +++ b/tests/lightning_models/test_taggers/test_transformer_tagger.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- +r""" +Test TransformerTagger integration with Lightning +============== +""" +import unittest +from unittest.mock import Mock + +import torch + +from caption.models.taggers import TransformerTagger +from test_tube import HyperOptArgumentParser + + +class TestTransformerTagger(unittest.TestCase): + @property + def hparams(self): + parser = HyperOptArgumentParser() + # metric mode and monitor are hparams required by COMET models + # and lightning trainer. 
+ parser.add_argument("--monitor", default="slot_error_rate") + parser.add_argument("--metric_mode", default="min") + parser = TransformerTagger.add_model_specific_args(parser) + hparams, _ = parser.parse_known_args([]) + return hparams + + @property + def samples(self): + """ Sample example """ + return [ + {"text": "hello world", "tags": "T L",}, + {"text": "how amazing is caption", "tags": "T L L U",}, + ] + + def setUp(self): + """ Setup will test """ + self.model = TransformerTagger(self.hparams) + + def test_run_model(self): + """ Function that tests the integration between: + prepare_sample, forward and _compute_loss + """ + inputs, targets = self.model.prepare_sample(self.samples) + predictions = self.model(**inputs) + loss_value = self.model._compute_loss(predictions, targets) + self.assertIsInstance(loss_value, torch.Tensor) + + def test_training_step(self): + """ Function that tests the integration between: + prepare_sample, forward and _compute_loss. + """ + # Single-GPU Distributed Parallel training step + self.model.trainer = Mock(use_dp=True, num_gpus=1) + batch = self.model.prepare_sample(self.samples) + result = self.model.training_step(batch, 0) + assert "loss" in result + + # MultiGPU Distributed Parallel training step + self.model.trainer = Mock(use_dp=True, num_gpus=2) + batch = self.model.prepare_sample(self.samples) + result = self.model.training_step(batch, 0) + assert "loss" in result + + def test_validation_step(self): + """ Function that tests the integration between: + prepare_sample, forward and _compute_loss in validation. + """ + self.model.eval() + self.model.trainer = Mock(use_dp=True, num_gpus=1) + batch = self.model.prepare_sample(self.samples) + result = self.model.validation_step(batch, 0, 0) + assert "val_loss" in result + assert "val_prediction" in result + assert "val_target" in result + + def test_validation_end(self): + """ Function that tests the integration between: + validation_step and validation_epoch_end + """ + self.model.eval() + self.model.trainer = Mock(use_dp=True, num_gpus=1) + + with torch.no_grad(): + # Simulation of the first validation dataloader. + outputs_first_dataloader = [] + for i in range(5): + batch = self.model.prepare_sample(self.samples) + outputs_first_dataloader.append(self.model.validation_step(batch, 0, 0)) + + # Simulation of the second validation dataloader. 
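+            # A dataloader index of 1 is passed so validation_step can tell
+            # these batches apart from the first dataloader's batches.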
+ outputs_second_dataloader = [] + for i in range(5): + batch = self.model.prepare_sample(self.samples) + outputs_second_dataloader.append( + self.model.validation_step(batch, 0, 1) + ) + + outputs = (outputs_first_dataloader, outputs_second_dataloader) + result = self.model.validation_epoch_end(outputs) + self.assertIsInstance(result["log"]["val_loss"], torch.Tensor) + self.assertIsInstance(result["log"]["train_loss"], torch.Tensor) + + assert "slot_error_rate" in result["log"].keys() + assert "L_f1_score" in result["log"].keys() + assert "U_f1_score" in result["log"].keys() + assert "T_f1_score" in result["log"].keys() + assert "macro_fscore" in result["log"].keys() + assert "micro_fscore" in result["log"].keys() + + def test_predict(self): + sample = self.samples[0] + result = self.model.predict(sample) + assert result["text"] == sample["text"] + assert result["tags"] == sample["tags"] + assert "predicted_tags" in result.keys() + assert "tagged_sequence" in result.keys() + + sample = self.samples + result = self.model.predict(sample) + assert len(result) == 2 + for i in range(2): + assert result[i]["text"] == sample[i]["text"] + assert result[i]["tags"] == sample[i]["tags"] + assert "predicted_tags" in result[i].keys() + assert "tagged_sequence" in result[i].keys() diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..0d9a5df --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +import unittest + +import numpy as np + +from caption.models.metrics import ( + classification_report, + confusion_matrix, + fscore, + precision, + recall, +) + + +class TestMetrics(unittest.TestCase): + @property + def y_pred(self): + return np.array([0, 0, 1, 2, 2, 1, 1, 0, 1, 0]) + + @property + def y(self): + return np.array([0, 1, 1, 2, 1, 2, 0, 2, 3, 3]) + + def test_confusion_matrix(self): + cm = confusion_matrix(self.y_pred, self.y, padding=3) + expected = np.array([[1, 1, 1, 0], [1, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 2]]) + assert np.array_equal(cm, expected) + + def test_recall(self): + tp = 3 + fp = 1 + fn = 1 + assert recall(tp, fp, fn) == 3 / 4 + + def test_precision(self): + tp = 3 + fp = 1 + fn = 1 + assert precision(tp, fp, fn) == 3 / 4 + + def test_fscore(self): + tp = 3 + fp = 1 + fn = 1 + assert fscore(tp, fp, fn) == 2 * ( + precision(tp, fp, fn) * recall(tp, fp, fn) + ) / (recall(tp, fp, fn) + precision(tp, fp, fn)) + + def test_classification_report(self): + out = classification_report( + self.y_pred, self.y, padding=3, labels={"L": 0, "U": 1, "T": 2}, ignore=0 + ) + assert out["L_f1_score"] == 2 * ((1 / 3) * (1 / 2)) / ((1 / 3) + (1 / 2)) + assert out["U_f1_score"] == 2 * ((1 / 3) * (1 / 3)) / ((1 / 3) + (1 / 3)) + assert out["T_f1_score"] == 2 * ((1 / 3) * (1 / 2)) / ((1 / 3) + (1 / 2)) + + assert out["macro_fscore"] == (out["U_f1_score"] + out["T_f1_score"]) / 2 + assert out["micro_fscore"] != 0 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_encoders/__init__.py b/tests/unit/test_encoders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_encoders/test_bert_encoder.py b/tests/unit/test_encoders/test_bert_encoder.py new file mode 100644 index 0000000..e806631 --- /dev/null +++ b/tests/unit/test_encoders/test_bert_encoder.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import unittest +from argparse import Namespace + +import torch +from transformers 
import BertTokenizer + +from caption.models.encoders import BERT + + +class TestBERTEncoder(unittest.TestCase): + def setUp(self): + # setup tests the from_pretrained function + hparams = Namespace(pretrained_model="bert-base-cased") + self.model_base = BERT.from_pretrained(hparams) + self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + + hparams = Namespace(pretrained_model="bert-large-cased") + self.model_large = BERT.from_pretrained(hparams) + + def test_num_layers(self): + assert self.model_base.num_layers == 13 + assert self.model_large.num_layers == 25 + + def test_output_units(self): + assert self.model_base.output_units == 768 + assert self.model_large.output_units == 1024 + + def test_prepare_sample(self): + sample = ["hello world, welcome to COMET!", "This is a batch"] + + model_input = self.model_base.prepare_sample(sample) + assert "tokens" in model_input + assert "lengths" in model_input + + # Sanity Check: This is already checked when testing the tokenizer. + expected = self.tokenizer.encode(sample[0]) + assert torch.equal(torch.tensor(expected), model_input["tokens"][0]) + assert len(expected) == model_input["lengths"][0] + + model_input = self.model_base.prepare_sample(sample, trackpos=True) + assert "tokens" in model_input + assert "lengths" in model_input + assert "word_boundaries" in model_input + assert "word_lengths" in model_input + + def test_forward(self): + sample = ["hello world!", "This is a batch"] + model_input = self.model_base.prepare_sample(sample) + model_out = self.model_base(**model_input) + + assert "wordemb" in model_out + assert "sentemb" in model_out + assert "all_layers" in model_out + assert "mask" in model_out + assert "extra" in model_out + + assert len(model_out["all_layers"]) == self.model_base.num_layers + assert self.model_base.output_units == model_out["sentemb"].size()[1] + assert self.model_base.output_units == model_out["wordemb"].size()[2] diff --git a/tests/unit/test_optimizers/__init__.py b/tests/unit/test_optimizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_optimizers/test_adam.py b/tests/unit/test_optimizers/test_adam.py new file mode 100644 index 0000000..1383e13 --- /dev/null +++ b/tests/unit/test_optimizers/test_adam.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch + +from caption.optimizers import Adam +from test_tube import HyperOptArgumentParser + + +class TestAdamOptimizer(unittest.TestCase): + @property + def parameters(self): + return [torch.nn.Parameter(torch.randn(2, 3, 4))] + + @property + def hparams(self): + parser = HyperOptArgumentParser() + parser = Adam.add_optim_specific_args(parser) + return parser.parse_args() + + def test_from_hparams(self): + assert Adam.from_hparams(self.parameters, self.hparams) diff --git a/tests/unit/test_optimizers/test_adamax.py b/tests/unit/test_optimizers/test_adamax.py new file mode 100644 index 0000000..4a747a7 --- /dev/null +++ b/tests/unit/test_optimizers/test_adamax.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch + +from caption.optimizers import Adamax +from test_tube import HyperOptArgumentParser + + +class TestAdamaxOptimizer(unittest.TestCase): + @property + def parameters(self): + return [torch.nn.Parameter(torch.randn(2, 3, 4))] + + @property + def hparams(self): + parser = HyperOptArgumentParser() + parser = Adamax.add_optim_specific_args(parser) + return parser.parse_args() + + def test_from_hparams(self): + assert Adamax.from_hparams(self.parameters, 
self.hparams) diff --git a/tests/unit/test_optimizers/test_adamw.py b/tests/unit/test_optimizers/test_adamw.py new file mode 100644 index 0000000..fbb4612 --- /dev/null +++ b/tests/unit/test_optimizers/test_adamw.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch + +from caption.optimizers import AdamW +from test_tube import HyperOptArgumentParser + + +class TestAdamWOptimizer(unittest.TestCase): + @property + def parameters(self): + return [torch.nn.Parameter(torch.randn(2, 3, 4))] + + @property + def hparams(self): + parser = HyperOptArgumentParser() + parser = AdamW.add_optim_specific_args(parser) + return parser.parse_args() + + def test_from_hparams(self): + assert AdamW.from_hparams(self.parameters, self.hparams) diff --git a/tests/unit/test_optimizers/test_radam.py b/tests/unit/test_optimizers/test_radam.py new file mode 100644 index 0000000..16b4120 --- /dev/null +++ b/tests/unit/test_optimizers/test_radam.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch + +from caption.optimizers import RAdam +from test_tube import HyperOptArgumentParser + + +class TestRAdamOptimizer(unittest.TestCase): + @property + def parameters(self): + return [torch.nn.Parameter(torch.randn(2, 3, 4))] + + @property + def hparams(self): + parser = HyperOptArgumentParser() + parser = RAdam.add_optim_specific_args(parser) + return parser.parse_args() + + def test_from_hparams(self): + assert RAdam.from_hparams(self.parameters, self.hparams) diff --git a/tests/unit/test_schedulers/__init__.py b/tests/unit/test_schedulers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_schedulers/test_constant_scheduler.py b/tests/unit/test_schedulers/test_constant_scheduler.py new file mode 100644 index 0000000..2a533d3 --- /dev/null +++ b/tests/unit/test_schedulers/test_constant_scheduler.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch + +from caption.schedulers import ConstantLR +from test_tube import HyperOptArgumentParser + + +class TestConstantLRScheduler(unittest.TestCase): + @property + def optimizer(self): + params = [torch.nn.Parameter(torch.randn(2, 3, 4))] + return torch.optim.Adam(params) + + @property + def hparams(self): + parser = HyperOptArgumentParser() + parser = ConstantLR.add_scheduler_specific_args(parser) + return parser.parse_args() + + def test_scheduler_init(self): + assert ConstantLR.from_hparams(self.optimizer, self.hparams) diff --git a/tests/unit/test_schedulers/test_linear_warmup_scheduler.py b/tests/unit/test_schedulers/test_linear_warmup_scheduler.py new file mode 100644 index 0000000..e92867e --- /dev/null +++ b/tests/unit/test_schedulers/test_linear_warmup_scheduler.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch + +from caption.schedulers import LinearWarmup +from test_tube import HyperOptArgumentParser + + +class TestLinearWarmupScheduler(unittest.TestCase): + @property + def optimizer(self): + params = [torch.nn.Parameter(torch.randn(2, 3, 4))] + return torch.optim.Adam(params) + + @property + def hparams(self): + parser = HyperOptArgumentParser() + parser = LinearWarmup.add_scheduler_specific_args(parser) + return parser.parse_args() + + def test_scheduler_init(self): + assert LinearWarmup.from_hparams(self.optimizer, self.hparams) diff --git a/tests/unit/test_schedulers/test_warmup_constant_scheduler.py b/tests/unit/test_schedulers/test_warmup_constant_scheduler.py new file mode 100644 index 0000000..52a3b3f --- /dev/null +++ 
b/tests/unit/test_schedulers/test_warmup_constant_scheduler.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch + +from caption.schedulers import WarmupConstant +from test_tube import HyperOptArgumentParser + + +class TestWarmupConstantScheduler(unittest.TestCase): + @property + def optimizer(self): + params = [torch.nn.Parameter(torch.randn(2, 3, 4))] + return torch.optim.Adam(params) + + @property + def hparams(self): + parser = HyperOptArgumentParser() + parser = WarmupConstant.add_scheduler_specific_args(parser) + return parser.parse_args() + + def test_scheduler_init(self): + assert WarmupConstant.from_hparams(self.optimizer, self.hparams) diff --git a/tests/unit/test_tokenizers/__init__.py b/tests/unit/test_tokenizers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_tokenizers/test_bert_tokenizer.py b/tests/unit/test_tokenizers/test_bert_tokenizer.py new file mode 100644 index 0000000..a94ad99 --- /dev/null +++ b/tests/unit/test_tokenizers/test_bert_tokenizer.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +import unittest + +import torch +from transformers import BertTokenizer + +from caption.tokenizers import BERTTextEncoder + + +class TestBERTTextEncoder(unittest.TestCase): + def setUp(self): + self.tokenizer = BERTTextEncoder("bert-base-multilingual-cased") + self.original_tokenizer = BertTokenizer.from_pretrained( + "bert-base-multilingual-cased" + ) + + def test_unk_property(self): + assert self.tokenizer.unk_index == 100 + + def test_pad_property(self): + assert self.tokenizer.padding_index == 0 + + def test_bos_property(self): + assert self.tokenizer.bos_index == 101 + + def test_eos_property(self): + assert self.tokenizer.eos_index == 102 + + def test_mask_property(self): + assert self.tokenizer.mask_index == 103 + + def test_vocab_property(self): + assert isinstance(self.tokenizer.vocab, dict) + + def test_vocab_size_property(self): + assert self.tokenizer.vocab_size > 0 + + def test_get_special_tokens_mask(self): + tensor = torch.tensor([101, 35378, 759, 10269, 56112, 25, 18, 91010, 297, 102]) + mask = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 1]) + result = self.tokenizer.get_special_tokens_mask(tensor) + assert torch.equal(mask, torch.tensor(result)) + + tensor = torch.tensor([101, 35378, 759, 102, 25, 18, 91010, 297, 102]) + mask = torch.tensor([1, 0, 0, 1, 0, 0, 0, 0, 1]) + result = self.tokenizer.get_special_tokens_mask(tensor) + assert torch.equal(mask, torch.tensor(result)) + + def test_encode(self): + sentence = "Hello, my dog is cute" + expected = self.original_tokenizer.encode(sentence) + result = self.tokenizer.encode(sentence) + assert torch.equal(torch.tensor(expected), result) + # Make sure the bos and eos tokens were added. + assert result[0] == self.tokenizer.bos_index + assert result[-1] == self.tokenizer.eos_index + + def test_batch_encode(self): + # Test batch_encode. + batch = ["Hello, my dog is cute", "hello world!"] + encoded_batch, lengths = self.tokenizer.batch_encode(batch) + + assert torch.equal(encoded_batch[0], self.tokenizer.encode(batch[0])) + assert torch.equal( + encoded_batch[1][: lengths[1]], self.tokenizer.encode(batch[1]) + ) + assert lengths[0] == len( + self.original_tokenizer.encode("Hello, my dog is cute") + ) + assert lengths[1] == len(self.original_tokenizer.encode("hello world!")) + + # Check if last sentence is padded. 
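+        # "hello world!" is shorter, so it is right-padded with the padding
+        # index up to the length of the first sentence.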
+ assert encoded_batch[1][-1] == self.tokenizer.padding_index + assert encoded_batch[1][-2] == self.tokenizer.padding_index + + def test_encode_trackpos(self): + sentence = "Hello my dog isn't retarded" + result = self.tokenizer.encode_trackpos(sentence) + + # [CLS] hello my dog isn ' t ret ##arde ##d [SEP]" + # 0 1 2 3 4 5 6 7 8 9 + expected = ( + torch.tensor( + [101, 31178, 15127, 17835, 98370, 112, 188, 62893, 45592, 10162, 102] + ), + torch.tensor([1, 2, 3, 4, 7]), + ) + assert torch.equal(result[0], expected[0]) + assert torch.equal(result[1], expected[1]) + + def test_batch_encode_trackpos(self): + batch = ["Hello my dog isn't retarded", "retarded"] + result = self.tokenizer.batch_encode_trackpos(batch) + + # [CLS] hello my dog isn ' t ret ##arde ##d [SEP]" + # 0 1 2 3 4 5 6 7 8 9 + expected1 = ( + torch.tensor( + [101, 31178, 15127, 17835, 98370, 112, 188, 62893, 45592, 10162, 102] + ), + torch.tensor([1, 2, 3, 4, 7]), + ) + + # [CLS] ret ##arde ##d [SEP]" + # 0 1 2 3 + expected2 = ( + torch.tensor([101, 62893, 45592, 10162, 102, 0, 0, 0, 0, 0, 0]), + torch.tensor([1, 0, 0, 0, 0]), + ) + + assert torch.equal(result[0][0], expected1[0]) + assert torch.equal(result[0][1], expected2[0]) + assert torch.equal(result[2][0], expected1[1]) + assert torch.equal(result[2][1], expected2[1]) diff --git a/tests/unit/test_tokenizers/test_roberta_tokenizer.py b/tests/unit/test_tokenizers/test_roberta_tokenizer.py new file mode 100644 index 0000000..db50eee --- /dev/null +++ b/tests/unit/test_tokenizers/test_roberta_tokenizer.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +import os +import unittest + +import torch + +from caption.tokenizers import RoBERTaTextEncoder +from fairseq.models.roberta import RobertaModel +from test_tube import HyperOptArgumentParser +from torchnlp.download import download_file_maybe_extract + +download_file_maybe_extract( + "https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz", + directory=os.environ["HOME"] + "/.cache/caption/", + check_files=["roberta.base/model.pt"], +) +roberta = RobertaModel.from_pretrained( + os.environ["HOME"] + "/.cache/caption/roberta.base", checkpoint_file="model.pt", +) +original_vocab = roberta.task.source_dictionary.__dict__["indices"] +tokenizer = RoBERTaTextEncoder(roberta.encode, original_vocab) + + +class TestRoBERTaTextEncoder(unittest.TestCase): + def test_unk_property(self): + assert tokenizer.unk_index == original_vocab[""] + + def test_pad_property(self): + assert tokenizer.padding_index == original_vocab[""] + + def test_bos_property(self): + assert tokenizer.bos_index == original_vocab[""] + + def test_eos_property(self): + assert tokenizer.eos_index == original_vocab[""] + + def test_mask_property(self): + assert tokenizer.mask_index == original_vocab[""] + + def test_vocab_property(self): + assert isinstance(tokenizer.vocab, dict) + + def test_vocab_size_property(self): + assert tokenizer.vocab_size == len(original_vocab) + + def test_get_special_tokens_mask(self): + tensor = torch.tensor([0, 35378, 759, 10269, 56112, 25, 18, 91010, 297, 2]) + mask = torch.tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 1]) + result = tokenizer.get_special_tokens_mask(tensor) + assert torch.equal(mask, torch.tensor(result)) + + tensor = torch.tensor([0, 35378, 759, 2, 2, 25, 18, 91010, 297, 2]) + mask = torch.tensor([1, 0, 0, 1, 1, 0, 0, 0, 0, 1]) + result = tokenizer.get_special_tokens_mask(tensor) + assert torch.equal(mask, torch.tensor(result)) + + def test_encode(self): + sentence = "Hello, my dog is cute" + expected = 
roberta.encode(sentence) + result = tokenizer.encode(sentence) + assert torch.equal(expected, result) + # Make sure the bos and eos tokens were added. + assert result[0] == tokenizer.bos_index + assert result[-1] == tokenizer.eos_index + + def test_batch_encode(self): + # Test batch_encode. + batch = ["Hello, my dog is cute", "hello world!"] + encoded_batch, lengths = tokenizer.batch_encode(batch) + + assert torch.equal(encoded_batch[0], tokenizer.encode(batch[0])) + assert torch.equal(encoded_batch[1][: lengths[1]], tokenizer.encode(batch[1])) + assert lengths[0] == len(roberta.encode("Hello, my dog is cute")) + assert lengths[1] == len(roberta.encode("hello world!")) + + # Check if last sentence is padded. + assert encoded_batch[1][-1] == tokenizer.padding_index + assert encoded_batch[1][-2] == tokenizer.padding_index + + def test_encode_trackpos(self): + sentence = "Hello my dog isn't retarded" + result = tokenizer.encode_trackpos(sentence) + + # "" "Hello" " my" " dog" " isn" "'t" " retarded" "" + # 0 1 2 3 4 5 6 7 + expected = ( + torch.tensor([0, 31414, 127, 2335, 965, 75, 47304, 2]), + torch.tensor([1, 2, 3, 4, 6]), + ) + assert torch.equal(result[0], expected[0]) + assert torch.equal(result[1], expected[1]) + + def test_batch_encode_trackpos(self): + batch = ["Hello my dog isn't retarded", "retarded"] + result = tokenizer.batch_encode_trackpos(batch) + + # "" "Hello" " my" " dog" " isn" "'t" " retarded" "" + # 0 1 2 3 4 5 6 7 + expected1 = ( + torch.tensor([0, 31414, 127, 2335, 965, 75, 47304, 2]), + torch.tensor([1, 2, 3, 4, 6]), + ) + # NOTE: since retarded appears in the begin the bpe encoder will + # not encode it as " retarded" anymore. + # "" "ret" "arded" "" + # 0 1 2 + expected2 = ( + torch.tensor([0, 4903, 16230, 2, 1, 1, 1, 1]), + torch.tensor([1, 0, 0, 0, 0]), + ) + assert torch.equal(result[0][0], expected1[0]) + assert torch.equal(result[0][1], expected2[0]) + assert torch.equal(result[2][0], expected1[1]) + assert torch.equal(result[2][1], expected2[1])
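
For reference, below is a minimal usage sketch of the tokenizer API added in this patch, mirroring the calls exercised by the unit tests above. It is illustrative only and not part of the patch: it assumes the caption package is installed and that the bert-base-multilingual-cased weights can be downloaded, and the names word_boundaries / word_lengths are my own labels for the third and fourth return values of batch_encode_trackpos.

# -*- coding: utf-8 -*-
# Hypothetical usage sketch: exercises the calls covered by the unit tests.
from caption.tokenizers import BERTTextEncoder

tokenizer = BERTTextEncoder("bert-base-multilingual-cased")
batch = ["Hello, my dog is cute", "hello world!"]

# Padded token ids plus the original sequence lengths.
tokens, lengths = tokenizer.batch_encode(batch)

# Alignment-aware encoding: additionally returns, for each whitespace token,
# the index of its first wordpiece inside the encoded sequence.
tokens, lengths, word_boundaries, word_lengths = tokenizer.batch_encode_trackpos(batch)

print(tokens.shape, lengths.tolist(), word_boundaries.shape, word_lengths.tolist())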