From 213101704693331a83de0a99377a0367759e1af0 Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Thu, 3 Oct 2024 17:18:34 +0100
Subject: [PATCH 01/21] dummy change

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 31f1d1d4..d5191245 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,4 @@
+
 ![Maturity level-1](https://img.shields.io/badge/Maturity%20Level-ML--2-green)

From e50240aa10ca8bdbda17893c79678eec5de5f2ae Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Thu, 3 Oct 2024 17:34:23 +0100
Subject: [PATCH 02/21] try without pip cache

---
 .github/workflows/precommit_and_docs_build_check.yml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/precommit_and_docs_build_check.yml b/.github/workflows/precommit_and_docs_build_check.yml
index 92ba8bb8..2bc80bcc 100644
--- a/.github/workflows/precommit_and_docs_build_check.yml
+++ b/.github/workflows/precommit_and_docs_build_check.yml
@@ -16,7 +16,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: '3.11'
+          python-version: "3.11"
           cache: pip

       - name: create venv and install dependencies
@@ -29,7 +29,7 @@ jobs:
           . /tmp/kazu-env/bin/activate
           python -m pip install --upgrade pip
           pip install --upgrade --upgrade-strategy eager --index-url https://download.pytorch.org/whl/cpu "torch>=2.0.0"
-          pip install -e ."[dev]" --cache-dir $PIP_CACHE_DIR --upgrade --upgrade-strategy eager
+          pip install -e ."[dev]"

       - name: Check precommit
         run: |
@@ -39,7 +39,7 @@
       - name: Check docs build successfully
         # still run even if the pre-commit fails, so we
         # have the output on if the docs build succeeded or not
-        if: '!cancelled()'
+        if: "!cancelled()"
         run: |
           . /tmp/kazu-env/bin/activate
           make -C docs html

From b728d5bafd8f5089d218c03a942b5a9dae9f67bb Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Thu, 3 Oct 2024 17:48:10 +0100
Subject: [PATCH 03/21] try forcing pandas>2

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index ab447f3a..f7698066 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
     "rdflib>=6.0.0",
     "requests>=2.20.0",
     "hydra-core>=1.3.0",
-    "pandas>=1.0.0",
+    "pandas>=2.0.0",
     "pyarrow",
     "pyahocorasick",
     "pymongo>=4.3.3",

From dba74c029684dbe6d7ed7d1adcf27e523f031007 Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Fri, 4 Oct 2024 09:41:10 +0100
Subject: [PATCH 04/21] test disabled pip cache

---
 .github/workflows/precommit_and_docs_build_check.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/precommit_and_docs_build_check.yml b/.github/workflows/precommit_and_docs_build_check.yml
index 2bc80bcc..63c436df 100644
--- a/.github/workflows/precommit_and_docs_build_check.yml
+++ b/.github/workflows/precommit_and_docs_build_check.yml
@@ -29,7 +29,7 @@ jobs:
           . /tmp/kazu-env/bin/activate
           python -m pip install --upgrade pip
           pip install --upgrade --upgrade-strategy eager --index-url https://download.pytorch.org/whl/cpu "torch>=2.0.0"
-          pip install -e ."[dev]"
+          pip install -e ."[dev]" --no-cache-dir

       - name: Check precommit
         run: |

From 69d98bb6c75b27ebdba14f7333849f061760508a Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Mon, 7 Oct 2024 09:32:28 +0100
Subject: [PATCH 05/21] test numpy > 2

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index f7698066..9a6ab8ba 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "requests>=2.20.0",
     "hydra-core>=1.3.0",
     "pandas>=2.0.0",
+    "numpy>=2.0.0",
     "pyarrow",
     "pyahocorasick",
     "pymongo>=4.3.3",

From 25045f6951149a89a87bfdb653534dfb02b2f3db Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Mon, 7 Oct 2024 10:26:42 +0100
Subject: [PATCH 06/21] try installing openblas

---
 .github/workflows/precommit_and_docs_build_check.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/precommit_and_docs_build_check.yml b/.github/workflows/precommit_and_docs_build_check.yml
index 63c436df..c4afcb3d 100644
--- a/.github/workflows/precommit_and_docs_build_check.yml
+++ b/.github/workflows/precommit_and_docs_build_check.yml
@@ -21,6 +21,8 @@ jobs:

       - name: create venv and install dependencies
         run: |
+          apt-get update
+          apt install libopenblas-dev
           rm -r /tmp/kazu-env || true
           PIP_CACHE_DIR=$(pip cache dir)

From d7d205c5426990ed59778a19f459b1b4e1e73121 Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Mon, 7 Oct 2024 10:38:27 +0100
Subject: [PATCH 07/21] revert changes and try different scipy version

---
 .github/workflows/precommit_and_docs_build_check.yml | 2 --
 pyproject.toml | 5 ++---
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/precommit_and_docs_build_check.yml b/.github/workflows/precommit_and_docs_build_check.yml
index c4afcb3d..63c436df 100644
--- a/.github/workflows/precommit_and_docs_build_check.yml
+++ b/.github/workflows/precommit_and_docs_build_check.yml
@@ -21,8 +21,6 @@ jobs:

       - name: create venv and install dependencies
         run: |
-          apt-get update
-          apt install libopenblas-dev
           rm -r /tmp/kazu-env || true
           PIP_CACHE_DIR=$(pip cache dir)

diff --git a/pyproject.toml b/pyproject.toml
index 9a6ab8ba..2ac4d07f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,8 +26,7 @@ dependencies = [
     "rdflib>=6.0.0",
     "requests>=2.20.0",
     "hydra-core>=1.3.0",
-    "pandas>=2.0.0",
-    "numpy>=2.0.0",
+    "pandas",
     "pyarrow",
     "pyahocorasick",
     "pymongo>=4.3.3",
@@ -35,7 +34,7 @@
     "scikit-learn>=0.24.0",
     # scipy 1.12.0 introduced many changes to the sparse matrices api. https://docs.scipy.org/doc/scipy/reference/sparse.html#module-scipy.sparse
     # This is causing our acceptance tests to fail. Pinning to <1.12.0 until it's confirmed other libraries (e.g. sk-learn) don't have issues.
-    "scipy<1.12.0",
+    "scipy",
     "regex>=2020.1.7",
     "psutil>=5.3.0",
     "cachetools>=5.2.0",

From a29425cf6b41dce4452c7602dca85df64326ff4e Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Mon, 7 Oct 2024 11:02:32 +0100
Subject: [PATCH 08/21] try sudo install openblas

---
 .github/workflows/precommit_and_docs_build_check.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/precommit_and_docs_build_check.yml b/.github/workflows/precommit_and_docs_build_check.yml
index 63c436df..f8951000 100644
--- a/.github/workflows/precommit_and_docs_build_check.yml
+++ b/.github/workflows/precommit_and_docs_build_check.yml
@@ -21,6 +21,8 @@ jobs:

       - name: create venv and install dependencies
         run: |
+          sudo apt-get update
+          sudo apt install libopenblas-dev
           rm -r /tmp/kazu-env || true
           PIP_CACHE_DIR=$(pip cache dir)

From 16f0170b8bc9b3adab1dc6609232722c24371130 Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Mon, 7 Oct 2024 11:16:22 +0100
Subject: [PATCH 09/21] unpin types-requests

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 2ac4d07f..0bebfb0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,7 +70,7 @@ typed = [
     # version 2.31.0.3 introduced
     # overly strict typing of requests.HttpError
     # which was fixed in 2.31.0.6
-    "types-requests!=2.31.0.3,!=2.31.0.4,!=2.31.0.5",
+    "types-requests",
     "types-cachetools",
     "types-regex",
     "types-psutil",

From 1255790678e795559424ec8c83cfe650eaa7a565 Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Mon, 7 Oct 2024 11:21:42 +0100
Subject: [PATCH 10/21] remove types-requests

---
 pyproject.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0bebfb0d..5e0c27ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,7 +70,6 @@ typed = [
     # version 2.31.0.3 introduced
     # overly strict typing of requests.HttpError
     # which was fixed in 2.31.0.6
-    "types-requests",
     "types-cachetools",
     "types-regex",
     "types-psutil",

From f03c0e1c02ebc8cb95d5d492008804552a06e255 Mon Sep 17 00:00:00 2001
From: Szymon Palucha
Date: Mon, 7 Oct 2024 12:09:23 +0100
Subject: [PATCH 11/21] feat: remove distillation training code causing issues

---
 .../precommit_and_docs_build_check.yml | 4 +-
 kazu/conf/DistillationTraining/default.yaml | 71 --
 .../intermediatelayer.yaml | 72 --
 .../conf/DistillationTraining/production.yaml | 72 --
 kazu/conf/config.yaml | 5 +-
 kazu/distillation/__init__.py | 0
 kazu/distillation/dataprocessor.py | 132 ----
 kazu/distillation/lightning_plugins.py | 73 --
 kazu/distillation/metrics.py | 44 --
 kazu/distillation/models.py | 747 ------------------
 kazu/distillation/tiny_transformers.py | 95 ---
 kazu/distillation/train.py | 35 -
 kazu/tests/test_distillation_training.py | 18 -
 pyproject.toml | 25 +-
 14 files changed, 8 insertions(+), 1385 deletions(-)
 delete mode 100644 kazu/conf/DistillationTraining/default.yaml
 delete mode 100644 kazu/conf/DistillationTraining/intermediatelayer.yaml
 delete mode 100644 kazu/conf/DistillationTraining/production.yaml
 delete mode 100644 kazu/distillation/__init__.py
 delete mode 100644 kazu/distillation/dataprocessor.py
 delete mode 100644 kazu/distillation/lightning_plugins.py
 delete mode 100644 kazu/distillation/metrics.py
 delete mode 100644 kazu/distillation/models.py
 delete mode 100644 kazu/distillation/tiny_transformers.py
 delete mode 100644 kazu/distillation/train.py
 delete mode 100644 kazu/tests/test_distillation_training.py

diff --git a/.github/workflows/precommit_and_docs_build_check.yml
b/.github/workflows/precommit_and_docs_build_check.yml index f8951000..9778e828 100644 --- a/.github/workflows/precommit_and_docs_build_check.yml +++ b/.github/workflows/precommit_and_docs_build_check.yml @@ -21,8 +21,6 @@ jobs: - name: create venv and install dependencies run: | - sudo apt-get update - sudo apt install libopenblas-dev rm -r /tmp/kazu-env || true PIP_CACHE_DIR=$(pip cache dir) @@ -31,7 +29,7 @@ jobs: . /tmp/kazu-env/bin/activate python -m pip install --upgrade pip pip install --upgrade --upgrade-strategy eager --index-url https://download.pytorch.org/whl/cpu "torch>=2.0.0" - pip install -e ."[dev]" --no-cache-dir + pip install -e ."[dev]" --cache-dir $PIP_CACHE_DIR --upgrade --upgrade-strategy eager" - name: Check precommit run: | diff --git a/kazu/conf/DistillationTraining/default.yaml b/kazu/conf/DistillationTraining/default.yaml deleted file mode 100644 index 27c8c394..00000000 --- a/kazu/conf/DistillationTraining/default.yaml +++ /dev/null @@ -1,71 +0,0 @@ -seed: 42 -cudnn: - deterministic: True - benchmark: False -monitoring: - monitor: entity_f1 - mode: max -training_params: - max_epochs: 10 -save_dir: "." # save dir managed by hydra. Override here if required -model: - _target_: kazu.distillation.models.SequenceTaggingDistillationForFinalLayer - data_dir: ??? - label_list: - - B-disease - - B-drug - - B-gene - - B-mutation - - B-species - - I-disease - - I-drug - - I-gene - - I-mutation - - I-species - - O - student_model_path: ??? - teacher_model_path: ??? - batch_size: 8 - max_length: 128 - num_workers: 2 - temperature: 1.0 - warmup_steps: 0 - learning_rate: 5e-5 - schedule: warmup_linear - weight_decay: 0.01 - accumulate_grad_batches: ${DistillationTraining.trainer.accumulate_grad_batches} - max_epochs: ${DistillationTraining.training_params.max_epochs} - metric: ${DistillationTraining.monitoring.monitor} -trainer: - _convert_: 'partial' #needed to convert ListConfig to list for Pytorch lightning plugins parameter, which needs a list (ver 1.6.4) - _target_: pytorch_lightning.Trainer - num_sanity_val_steps: 2 - accelerator: "cpu" - val_check_interval: 1.0 - accumulate_grad_batches: 1 - max_epochs: ${DistillationTraining.training_params.max_epochs} - logger: - - _target_: pytorch_lightning.loggers.CSVLogger - save_dir: ${DistillationTraining.save_dir} - name: csv_log - plugins: - - _target_: kazu.distillation.lightning_plugins.StudentModelCheckpointIO - model_name_or_path: ${DistillationTraining.model.student_model_path} - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: ${DistillationTraining.save_dir} - filename: "student_model-{epoch:02d}-{entity_f1:.4f}-{validation_loss:.3f}-{step:05d}" - monitor: ${DistillationTraining.monitoring.monitor} - mode: ${DistillationTraining.monitoring.mode} - save_top_k: 5 - save_last: True - every_n_train_steps: ~ - every_n_epochs: ~ - - _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping - monitor: ${DistillationTraining.monitoring.monitor} - mode: ${DistillationTraining.monitoring.mode} - min_delta: 0.00 - patience: 5 - verbose: True - - _target_: pytorch_lightning.callbacks.progress.TQDMProgressBar - refresh_rate: 1 diff --git a/kazu/conf/DistillationTraining/intermediatelayer.yaml b/kazu/conf/DistillationTraining/intermediatelayer.yaml deleted file mode 100644 index ede32ce3..00000000 --- a/kazu/conf/DistillationTraining/intermediatelayer.yaml +++ /dev/null @@ -1,72 +0,0 @@ -seed: 42 -cudnn: - deterministic: True - benchmark: False -monitoring: - monitor: mserror - mode: min 
-training_params: - max_epochs: 5 -save_dir: "." # save dir managed by hydra. Override here if required -model: - _target_: kazu.distillation.models.SequenceTaggingDistillationForIntermediateLayer - data_dir: ??? - label_list: - - B-disease - - B-drug - - B-gene - - B-mutation - - B-species - - I-disease - - I-drug - - I-gene - - I-mutation - - I-species - - O - student_model_path: ??? - teacher_model_path: ??? - batch_size: 128 - max_length: 128 - num_workers: 4 - temperature: 1.0 - warmup_steps: 0 - learning_rate: 3e-5 - schedule: torchStepLR - weight_decay: 0.01 - accumulate_grad_batches: ${DistillationTraining.trainer.accumulate_grad_batches} - max_epochs: ${DistillationTraining.training_params.max_epochs} - metric: ${DistillationTraining.monitoring.monitor} -trainer: - _target_: pytorch_lightning.Trainer - num_sanity_val_steps: 2 - devices: "auto" - accelerator: "cpu" - val_check_interval: 0.05 - accumulate_grad_batches: 1 - max_epochs: ${DistillationTraining.training_params.max_epochs} - strategy: "ddp" - logger: - - _target_: pytorch_lightning.loggers.CSVLogger - save_dir: ${DistillationTraining.save_dir} - name: csv_log - plugins: - - _target_: kazu.distillation.lightning_plugins.StudentModelCheckpointIO - model_name_or_path: ${DistillationTraining.model.student_model_path} - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: ${DistillationTraining.save_dir} - filename: "student_model_inm_layer-{epoch:02d}-{mserror:.4f}-{step:05d}" - monitor: ${DistillationTraining.monitoring.monitor} - mode: ${DistillationTraining.monitoring.mode} - save_top_k: 5 - save_last: True - every_n_train_steps: ~ - every_n_epochs: ~ - - _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping - monitor: ${DistillationTraining.monitoring.monitor} - mode: ${DistillationTraining.monitoring.mode} - min_delta: 0.00 - patience: 10 - verbose: True - - _target_: pytorch_lightning.callbacks.progress.TQDMProgressBar - refresh_rate: 1 diff --git a/kazu/conf/DistillationTraining/production.yaml b/kazu/conf/DistillationTraining/production.yaml deleted file mode 100644 index f46c32f4..00000000 --- a/kazu/conf/DistillationTraining/production.yaml +++ /dev/null @@ -1,72 +0,0 @@ -seed: 42 -cudnn: - deterministic: True - benchmark: False -monitoring: - monitor: entity_f1 - mode: max -training_params: - max_epochs: 5 -save_dir: "." # save dir managed by hydra. Override here if required -model: - _target_: kazu.distillation.models.SequenceTaggingDistillationForFinalLayer - data_dir: ??? - label_list: - - B-disease - - B-drug - - B-gene - - B-mutation - - B-species - - I-disease - - I-drug - - I-gene - - I-mutation - - I-species - - O - student_model_path: ??? - teacher_model_path: ??? 
- batch_size: 128 - max_length: 128 - num_workers: 4 - temperature: 1.0 - warmup_steps: 0 - learning_rate: 3e-5 - schedule: torchStepLR - weight_decay: 0.01 - accumulate_grad_batches: ${DistillationTraining.trainer.accumulate_grad_batches} - max_epochs: ${DistillationTraining.training_params.max_epochs} - metric: ${DistillationTraining.monitoring.monitor} -trainer: - _target_: pytorch_lightning.Trainer - num_sanity_val_steps: 2 - devices: "auto" - accelerator: "cpu" - val_check_interval: 0.05 - accumulate_grad_batches: 1 - max_epochs: ${DistillationTraining.training_params.max_epochs} - strategy: "ddp" - logger: - - _target_: pytorch_lightning.loggers.CSVLogger - save_dir: ${DistillationTraining.save_dir} - name: csv_log - plugins: - - _target_: kazu.distillation.lightning_plugins.StudentModelCheckpointIO - model_name_or_path: ${DistillationTraining.model.student_model_path} - callbacks: - - _target_: pytorch_lightning.callbacks.ModelCheckpoint - dirpath: ${DistillationTraining.save_dir} - filename: "student_model-{epoch:02d}-{entity_f1:.4f}-{validation_loss:.3f}-{step:05d}" - monitor: ${DistillationTraining.monitoring.monitor} - mode: ${DistillationTraining.monitoring.mode} - save_top_k: 5 - save_last: True - every_n_train_steps: ~ - every_n_epochs: ~ - - _target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping - monitor: ${DistillationTraining.monitoring.monitor} - mode: ${DistillationTraining.monitoring.mode} - min_delta: 0.00 - patience: 10 - verbose: True - - _target_: pytorch_lightning.callbacks.progress.TQDMProgressBar - refresh_rate: 1 diff --git a/kazu/conf/config.yaml b/kazu/conf/config.yaml index e202a256..50f05a06 100644 --- a/kazu/conf/config.yaml +++ b/kazu/conf/config.yaml @@ -6,7 +6,7 @@ defaults: - ot_molecule - ot_bio_proc - ot_phenotype -# - ot_measurement # curations are currently low quality + # - ot_measurement # curations are currently low quality - ot_medical_proc - cellosaurus - chembl @@ -28,7 +28,6 @@ defaults: - DictionaryEntityLinkingStep: default - SapBertTraining: default - Pipeline: default - - DistillationTraining: default - OpsinStep: default - SethStep: default - SynonymGeneration: default @@ -50,7 +49,7 @@ defaults: - global_actions: default - GLiNERStep: default - autocurator: default - - _self_ # see https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/default_composition_order/ + - _self_ # see https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/default_composition_order/ # we set certain env vars here for things that are statically initialised hydra: diff --git a/kazu/distillation/__init__.py b/kazu/distillation/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kazu/distillation/dataprocessor.py b/kazu/distillation/dataprocessor.py deleted file mode 100644 index 31da8471..00000000 --- a/kazu/distillation/dataprocessor.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Modified for distillation using Pytorch Lightning by KAZU team. - -Based off of the `DataProcessor `_ -and `NerProcessor `_ classes in BioBERT, -also written in reference to Huawei Noah's Ark Lab `TinyBERT `_. - -Licensed under Apache 2.0 - -| Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team and Huawei Noah's Ark Lab. -| Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. - -.. raw:: html - -

- Full License Notice - -| Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team and Huawei Noah's Ark Lab. -| Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -.. raw:: html - -
-""" - -import logging -import os -from collections.abc import Iterable - -from transformers import InputExample, DataProcessor - -from kazu.utils.utils import PathLike - -logger = logging.getLogger(__name__) - - -# type ignore is necessary because transformers doesn't distribute type hints, -# so DataProcessor is seen as 'Any'. -class SeqTagProcessor(DataProcessor): # type: ignore[misc] - """Base class for data converters for sequence tagging data sets.""" - - def get_train_examples(self, data_dir: str) -> list[InputExample]: - """Gets a collection of :class:`transformers.InputExample` for the train set.""" - raise NotImplementedError - - def get_dev_examples(self, data_dir: str) -> list[InputExample]: - """Gets a collection of :class:`transformers.InputExample` for the dev set.""" - raise NotImplementedError - - def get_aug_examples(self, data_dir: str) -> list[InputExample]: - """Gets a collection of :class:`transformers.InputExample` for the aug set.""" - raise NotImplementedError - - -class NerProcessor(SeqTagProcessor): - def get_train_examples(self, data_dir: str) -> list[InputExample]: - return self._create_examples( - self._read_data(os.path.join(data_dir, "train_dev.tsv")), "train" - ) - - def get_dev_examples(self, data_dir: str) -> list[InputExample]: - return self._create_examples(self._read_data(os.path.join(data_dir, "devel.tsv")), "dev") - - def get_test_examples(self, data_dir: str) -> list[InputExample]: - """Gets a collection of :class:`transformers.InputExample` for the test set.""" - return self._create_examples(self._read_data(os.path.join(data_dir, "test.tsv")), "test") - - def get_aug_examples(self, data_dir: str) -> list[InputExample]: - return self._create_examples(self._read_tsv(os.path.join(data_dir, "train_aug.tsv")), "aug") - - def _create_examples( - self, lines: Iterable[tuple[str, str]], set_type: str - ) -> list[InputExample]: - examples = [] - for (i, line) in enumerate(lines): - guid = "%s-%s" % (set_type, i) - # TODO assert if tokenization from BERT and tokenization from pytorch mismatch - text = line[1] - label = line[0] - examples.append(InputExample(guid=guid, text_a=text, label=label)) - return examples - - @classmethod - def _read_data(cls, input_file: PathLike) -> list[tuple[str, str]]: - """Reads a BIO data.""" - with open(input_file) as inpFilept: - lines = [] - words: list[str] = [] - labels: list[str] = [] - continualLineErrorCnt = 0 - for lineIdx, line in enumerate(inpFilept): - contents = line.splitlines()[0] - lineList = contents.split() - if len(lineList) == 0: # For blank line - assert len(words) == len( - labels - ), "lineIdx: %s, len(words)(%s) != len(labels)(%s) \n %s\n%s" % ( - lineIdx, - len(words), - len(labels), - " ".join(words), - " ".join(labels), - ) - if len(words) != 0: - wordSent = " ".join(words) - labelSent = " ".join(labels) - lines.append((labelSent, wordSent)) - words = [] - labels = [] - else: - continualLineErrorCnt += 1 - else: - words.append(lineList[0]) - labels.append(lineList[-1]) - if len(words) != 0: - wordSent = " ".join(words) - labelSent = " ".join(labels) - lines.append((labelSent, wordSent)) - - logging.info("continualLineErrorCnt : %s" % (continualLineErrorCnt)) - return lines diff --git a/kazu/distillation/lightning_plugins.py b/kazu/distillation/lightning_plugins.py deleted file mode 100644 index 1eadb6ef..00000000 --- a/kazu/distillation/lightning_plugins.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -from pathlib import Path -from typing import Any, Optional, Union - -import torch -from omegaconf 
import OmegaConf -from pytorch_lightning.plugins import CheckpointIO -from transformers import AutoTokenizer - - -# lightning doesn't distribute type information, so to mypy -# this is subclassing 'Any'. -class StudentModelCheckpointIO(CheckpointIO): # type: ignore[misc] - """A plugin for saving student model (without saving teacher model)""" - - def __init__(self, model_name_or_path: str): - super().__init__() - self.model_name_or_path = model_name_or_path - - def save_checkpoint( - self, - checkpoint: dict[str, Any], - path: Union[str, Path], - storage_options: Optional[Any] = None, - ) -> None: - """Save distilled (student) model. Loading currently not implemented. - - :param checkpoint: contents to save. Including ``state_dict``, ``optimizer_states`` and ``callbacks``. - :param path: - :param storage_options: - """ - - dirPath = os.path.dirname(path) - - output_config_file = os.path.join(dirPath, "hyper_parameters.json") - OmegaConf.save( - OmegaConf.create(checkpoint["hyper_parameters"]), output_config_file, resolve=True - ) - studentModel_state_dict = { - key[len("student_model.") :]: value - for key, value in checkpoint["state_dict"].items() - if key.startswith("student_model.") - } - teacherModel_state_dict = { - key: value - for key, value in checkpoint["state_dict"].items() - if key.startswith("teacher_model.") - } - assert len(checkpoint["state_dict"]) == len(studentModel_state_dict) + len( - teacherModel_state_dict - ), "Missing structures while saving trained model." - torch.save(studentModel_state_dict, path) - - AutoTokenizer.from_pretrained(self.model_name_or_path).save_vocabulary( - os.path.dirname(path) - ) - - def load_checkpoint( - self, path: Union[str, Path], storage_options: Optional[Any] = None - ) -> dict[str, Any]: - """Not currently implemented. - - See :external+pytorch_lightning:meth:`CheckpointIO.load_checkpoint ` - for details of the abstract method. - """ - raise NotImplementedError - - def remove_checkpoint( - self, - path: Union[str, Path], - ) -> None: - - os.remove(path) diff --git a/kazu/distillation/metrics.py b/kazu/distillation/metrics.py deleted file mode 100644 index 00d83583..00000000 --- a/kazu/distillation/metrics.py +++ /dev/null @@ -1,44 +0,0 @@ -try: - from seqeval.metrics import f1_score -except ImportError as e: - raise ImportError( - "Running the model distillation code requires seqeval to be installed.\n" - "We recommend running 'pip install kazu[model_training]' to get all model training" - " dependencies." - ) from e - - -IGNORE_IDX = -100 - - -def accuracy(preds, labels): - return (preds == labels).mean() - - -def numeric_label_f1_score( - preds: list[list[int]], golds: list[list[int]], label_list: list[str] -) -> float: - """Function to calculate F1 score using seqeval and numerical format labels. 
- - :param preds: 2d array of predicted label ids - :param golds: 2d array of gold standard ids - :param label_list: list of strings, for mappingids to labels - :return: - """ - - pred_clean_labels_list = [] - gold_clean_labels_list = [] - - assert len(preds) == len(golds) - for preds_id_sequence, golds_id_sequence in zip(preds, golds): - assert len(preds_id_sequence) == len(golds_id_sequence) - p_labels = [] - g_labels = [] - for pred_label, gold_label in zip(preds_id_sequence, golds_id_sequence): - if gold_label != IGNORE_IDX: - p_labels.append(label_list[pred_label]) - g_labels.append(label_list[gold_label]) - pred_clean_labels_list.append(p_labels) - gold_clean_labels_list.append(g_labels) - f1: float = f1_score(gold_clean_labels_list, pred_clean_labels_list) - return f1 diff --git a/kazu/distillation/models.py b/kazu/distillation/models.py deleted file mode 100644 index eaaa67a4..00000000 --- a/kazu/distillation/models.py +++ /dev/null @@ -1,747 +0,0 @@ -"""Influenced by Huawei Noah's Ark Lab `TinyBERT `_, but heavily modified -structurally to fit in our PyTorch Lightning training setup. - -`This section of the TinyBERT code `_ -in particular is relevant. - -Licensed under Apache 2.0 - -| Copyright 2020 Huawei Technologies Co., Ltd. -| Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team. -| Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. - -.. raw:: html - -
- Full License Notice - -| Copyright 2020 Huawei Technologies Co., Ltd. -| Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team. -| Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -.. raw:: html - -
-""" - -import logging -from typing import Union, Optional, Any, cast -from collections.abc import Callable - -import numpy as np -import pytorch_lightning as pl -import torch -from cachetools import LRUCache -from omegaconf import ListConfig, OmegaConf -from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS -from torch.nn import CrossEntropyLoss, MSELoss -from torch.optim.lr_scheduler import LRScheduler -from torch.optim import AdamW -from torch.utils.data import Dataset, DataLoader -from transformers import ( - AutoTokenizer, - InputExample, - DataCollatorForTokenClassification, - PreTrainedTokenizer, - PreTrainedTokenizerFast, - get_constant_schedule, - get_linear_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_constant_schedule_with_warmup, -) -from transformers.utils import check_min_version - -from kazu.distillation.dataprocessor import NerProcessor -from kazu.distillation.metrics import numeric_label_f1_score, IGNORE_IDX -from kazu.distillation.tiny_transformers import TinyBertForSequenceTagging - -check_min_version("4.0.0") # at least 4.0.0... for optimerzers - -logger = logging.getLogger(__name__) - -SCHEDULES: dict[Optional[str], Callable] = { - None: get_constant_schedule, - "none": get_constant_schedule, - "warmup_cosine": get_cosine_schedule_with_warmup, - "warmup_constant": get_constant_schedule_with_warmup, - "warmup_linear": get_linear_schedule_with_warmup, - "torchStepLR": torch.optim.lr_scheduler.StepLR, -} - - -class NerDataset(Dataset): - """A dataset used for Ner. - - designed for on the fly tokenisation to speed up multi processing. Uses caching to - prevent repeated processing - """ - - def __init__( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - examples: list[InputExample], - label_map: dict[str, int], - max_length: int, - ): - """ - - :param tokenizer: typically created from AutoTokenizer.from_pretrained - :param examples: a list of InputExample, typically created from a - :class:`kazu.distillation.dataprocessor.NerProcessor` - :param label_map: str to int mapping of labels - :param max_length: The maximum number of tokens per instance that the model can handle. - Inputs longer than max_length value will be truncated. 
- """ - self.label_map = label_map - self.examples = examples - self.tokenizer = tokenizer - self.max_length = max_length - self.call_count = 0 - self.cache: LRUCache[int, dict[str, list]] = LRUCache(5000) - - def __getitem__(self, index: int) -> dict[str, list]: - if index not in self.cache: - self.cache[index] = self.convert_single_example( - ex_index=index, example=self.examples[index] - ) - self.call_count += 1 - return self.cache[index] - - def __len__(self): - return len(self.examples) - - def convert_single_example(self, ex_index: int, example: InputExample) -> dict[str, list]: - textlist = example.text_a.split() - labellist = example.label.split() - tokens: list[str] = [] - labels: list[str] = [] - for i, word in enumerate(textlist): - tokenized = self.tokenizer.tokenize(word) - tokens.extend(tokenized) - label_1 = labellist[i] - for m, tok in enumerate(tokenized): - if m == 0: - labels.append(label_1) - else: - labels.append("X") - - ntokens: list[str] = [] - segment_ids: list[int] = [] - label_id: list[int] = [] - ntokens.append("[CLS]") - segment_ids.append(0) - label_id.append(IGNORE_IDX) - for i, token in enumerate(tokens): - ntokens.append(token) - segment_ids.append(0) - if labels[i] == "X": - label_id.append(IGNORE_IDX) - else: - label_id.append(self.label_map[labels[i]]) - - # Truncation - if len(ntokens) > self.max_length - 1: - assert (len(ntokens) == len(segment_ids)) and (len(ntokens) == len(label_id)) - ntokens = ntokens[: self.max_length - 1] - segment_ids = segment_ids[: self.max_length - 1] - label_id = label_id[: self.max_length - 1] - - ntokens.append("[SEP]") - segment_ids.append(0) - label_id.append(IGNORE_IDX) - - input_ids = self.tokenizer.convert_tokens_to_ids(ntokens) - # convert_tokens_to_ids can return a single int if a single token is passed, - # but we know here we're dealing with lists. - assert isinstance(input_ids, list) - # The mask has 1 for real tokens and 0 for padding tokens. - input_mask = [1] * len(input_ids) - if self.call_count < 4 and ex_index < 4: # Examples. Executed only once per model run - logger.info("*** Example ***") - logger.info("guid: %s" % (example.guid)) - logger.info("tokens: %s" % " ".join(tokens)) - logger.info("ntokens: %s" % " ".join(ntokens)) - logger.info("input_ids: %s" % " ".join(str(x) for x in input_ids)) - logger.info("input_mask: %s" % " ".join(str(x) for x in input_mask)) - logger.info("segment_ids: %s" % " ".join(str(x) for x in segment_ids)) - logger.info("label_id: %s" % " ".join(str(x) for x in label_id)) - - result = { - "labels": label_id, - "input_ids": input_ids, - "token_type_ids": segment_ids, - "attention_mask": input_mask, - } - return result - - -class TaskSpecificDistillation(pl.LightningModule): - def __init__( - self, - temperature: float, - warmup_steps: int, - learning_rate: float, - weight_decay: float, - batch_size: int, - accumulate_grad_batches: int, - max_epochs: int, - schedule: Optional[str] = None, - ): - """Base class for distillation on PyTorch Lightning platform. 
- - :param temperature: - :param warmup_steps: - :param learning_rate: - :param weight_decay: - :param batch_size: - :param accumulate_grad_batches: - :param max_epochs: - :param schedule: - """ - - super().__init__() - self.accumulate_grad_batches = accumulate_grad_batches - self.batch_size = batch_size - self.weight_decay = weight_decay - self.learning_rate = learning_rate - - self.temperature = temperature - self.schedule = schedule - self.training_examples = self.get_training_examples() - - self.num_train_optimization_steps = int( - len(self.training_examples) - / self.batch_size - / self.accumulate_grad_batches - * max_epochs - ) - self.warmup_steps = warmup_steps - logger.info( - "num_train_optimization_steps: {}, args.warmup_steps: {}, args.gradient_accumulation_steps: {}".format( - self.num_train_optimization_steps, self.warmup_steps, self.accumulate_grad_batches - ) - ) - - def get_training_examples(self) -> list[InputExample]: - """Subclasses should implement this. - - :return: - """ - raise NotImplementedError - - def get_optimizer_grouped_parameters(self, student_model): - param_optimizer = list(student_model.named_parameters()) - size = 0 - logger.info("student_model.named_parameters :") - for n, p in student_model.named_parameters(): - logger.info("n: {}".format(n)) - size += p.nelement() - logger.info("Total parameters: {}".format(size)) - - no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], - "weight_decay": self.weight_decay, - }, - { - "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - return optimizer_grouped_parameters - - def configure_optimizers( - self, - ): - """Configure optimizer and learning rate scheduler.""" - optimizer_grouped_parameters = self.get_optimizer_grouped_parameters( - student_model=self.student_model - ) - optimizer = AdamW(params=optimizer_grouped_parameters, lr=self.learning_rate) - - if self.schedule in ["torchStepLR"]: # Multi-GPU : must use torch scheduler - scheduler = SCHEDULES[self.schedule]( - optimizer, step_size=self.num_train_optimization_steps - ) # PyTorch scheduler - else: - scheduler = SCHEDULES[self.schedule]( - optimizer, - num_warmup_steps=self.warmup_steps, - num_training_steps=self.num_train_optimization_steps, - ) # transformers scheduler - - lr_scheduler_config = {"scheduler": scheduler, "interval": "step", "frequency": 1} - return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} - - -class SequenceTaggingDistillationBase(TaskSpecificDistillation): - def __init__( - self, - temperature: float, - warmup_steps: int, - learning_rate: float, - weight_decay: float, - batch_size: int, - accumulate_grad_batches: int, - max_epochs: int, - max_length: int, - data_dir: str, - label_list: Union[list, ListConfig], - student_model_path: str, - teacher_model_path: str, - num_workers: int, - schedule: Optional[str] = None, - metric: str = "Default", - ): - """Base class for sequence tagging (task-specific) distillation steps. 
- - :param temperature: - :param warmup_steps: - :param learning_rate: - :param weight_decay: - :param batch_size: - :param accumulate_grad_batches: - :param max_epochs: - :param max_length: - :param data_dir: - :param label_list: - :param student_model_path: - :param teacher_model_path: - :param num_workers: - :param schedule: - :param metric: - """ - self.processor = NerProcessor() - self.data_dir = data_dir - self.max_length = max_length - self.tokenizer: Union[ - PreTrainedTokenizer, PreTrainedTokenizerFast - ] = AutoTokenizer.from_pretrained(student_model_path) - super().__init__( - schedule=schedule, - accumulate_grad_batches=accumulate_grad_batches, - temperature=temperature, - warmup_steps=warmup_steps, - weight_decay=weight_decay, - learning_rate=learning_rate, - max_epochs=max_epochs, - batch_size=batch_size, - ) - - self.num_workers = num_workers - self.label_list: list[str] - if isinstance(label_list, ListConfig): - self.label_list = cast(list[str], OmegaConf.to_container(label_list)) - else: - self.label_list = label_list - self.num_labels = len(label_list) - self.label_map = {label: i for i, label in enumerate(label_list)} - self.metric = metric - - self.teacher_model = TinyBertForSequenceTagging.from_pretrained( - teacher_model_path, num_labels=self.num_labels - ) - self.student_model = TinyBertForSequenceTagging.from_pretrained( - student_model_path, num_labels=self.num_labels - ) - - def train_dataloader(self) -> TRAIN_DATALOADERS: - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.train_dataloader - `\\ .""" - dataset = NerDataset( - tokenizer=self.tokenizer, - examples=self.training_examples, - label_map=self.label_map, - max_length=self.max_length, - ) - collator = DataCollatorForTokenClassification(tokenizer=self.tokenizer, padding=True) - return DataLoader( - dataset=dataset, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - collate_fn=collator, - pin_memory=True, - persistent_workers=True, - ) - - def get_training_examples(self) -> list[InputExample]: - return self.processor.get_train_examples(self.data_dir) - - def val_dataloader(self) -> EVAL_DATALOADERS: - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.val_dataloader - `\\ .""" - examples = self.processor.get_dev_examples(self.data_dir) - dataset = NerDataset( - tokenizer=self.tokenizer, - examples=examples, - label_map=self.label_map, - max_length=self.max_length, - ) - collator = DataCollatorForTokenClassification(tokenizer=self.tokenizer, padding=True) - return DataLoader( - dataset=dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - collate_fn=collator, - pin_memory=True, - persistent_workers=True, - ) - - -class SequenceTaggingDistillationForFinalLayer(SequenceTaggingDistillationBase): - def __init__( - self, - temperature: float, - warmup_steps: int, - learning_rate: float, - weight_decay: float, - batch_size: int, - accumulate_grad_batches: int, - max_epochs: int, - max_length: int, - data_dir: str, - label_list: Union[list, ListConfig], - student_model_path: str, - teacher_model_path: str, - num_workers: int, - schedule: Optional[str] = None, - metric: str = "Default", - ): - """A class for sequence tagging (task-specific) final-layer distillation step. 
- - :param temperature: - :param warmup_steps: - :param learning_rate: - :param weight_decay: - :param batch_size: - :param accumulate_grad_batches: - :param max_epochs: - :param max_length: - :param data_dir: - :param label_list: - :param student_model_path: - :param teacher_model_path: - :param num_workers: - :param schedule: - :param metric: - """ - super().__init__( - temperature=temperature, - warmup_steps=warmup_steps, - weight_decay=weight_decay, - learning_rate=learning_rate, - batch_size=batch_size, - accumulate_grad_batches=accumulate_grad_batches, - max_epochs=max_epochs, - max_length=max_length, - data_dir=data_dir, - label_list=label_list, - student_model_path=student_model_path, - teacher_model_path=teacher_model_path, - num_workers=num_workers, - schedule=schedule, - metric=metric, - ) - # Loss function: self.soft_cross_entropy for training, CrossEntropyLoss for validation - self.loss = CrossEntropyLoss(ignore_index=IGNORE_IDX) - self.save_hyperparameters() - - def soft_cross_entropy(self, predicts, targets): - student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1) - targets_prob = torch.nn.functional.softmax(targets, dim=-1) - return (-targets_prob * student_likelihood).mean() - - def training_step(self, batch, batch_idx): - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.training_step - `\\ .""" - student_logits, student_atts, student_reps = self.student_model( - input_ids=batch["input_ids"], - token_type_ids=batch["token_type_ids"], - attention_mask=batch["attention_mask"], - is_student=True, - ) - - # self.teacher_model.eval() - with torch.no_grad(): - teacher_logits, teacher_atts, teacher_reps = self.teacher_model( - input_ids=batch["input_ids"], - token_type_ids=batch["token_type_ids"], - attention_mask=batch["attention_mask"], - ) - - loss = self.soft_cross_entropy( - student_logits / self.temperature, teacher_logits / self.temperature - ) - - # Logging - self.log("training_loss", loss, prog_bar=True, on_step=True) - scheduler = self.lr_schedulers() - assert isinstance(scheduler, LRScheduler) - lr_list = scheduler.get_last_lr() - self.log("lr", lr_list[0], prog_bar=True, on_step=True) - if lr_list[0] != lr_list[1]: - self.log("lr1", lr_list[1], prog_bar=True, on_step=True) - - return loss - - def validation_step(self, batch, batch_idx): - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.validation_step - `\\ .""" - logits, _, _ = self.student_model( - input_ids=batch["input_ids"], - token_type_ids=batch["token_type_ids"], - attention_mask=batch["attention_mask"], - is_student=True, - ) - - loss = ( - self.loss( - logits.view(-1, self.num_labels), - batch["labels"].view(-1), - ) - .mean() - .item() - ) - return { - "loss": loss, - "logits": logits.detach().cpu(), - "label_ids": batch["labels"].detach().cpu(), - "attention_mask": batch["attention_mask"].detach().cpu(), - } - - def validation_epoch_end(self, val_step_outputs): - """Implementation of :meth:`LightningModule.validation_epoch_end - `\\ .""" - epoch_loss_mean = np.mean([x["loss"] for x in val_step_outputs]) - self.log( - "validation_loss", - epoch_loss_mean, - prog_bar=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) - - preds, golds = [], [] - for output in val_step_outputs: - - attention_mask = output["attention_mask"] - golds.extend(self.tensor_to_jagged_array(output["label_ids"], attention_mask)) - preds.extend( - self.tensor_to_jagged_array(torch.argmax(output["logits"], dim=-1), attention_mask) - ) - - result = 
numeric_label_f1_score(preds=preds, golds=golds, label_list=self.label_list) - self.log( - self.metric, result, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True - ) # Micro F1 - - def tensor_to_jagged_array( - self, tensor: torch.Tensor, attention_mask: torch.Tensor - ) -> list[list[int]]: - result = [] - for arr, mask in zip(tensor.numpy(), attention_mask.numpy()): - result.append(arr[0 : mask.sum()].tolist()) - return result - - -class SequenceTaggingDistillationForIntermediateLayer(SequenceTaggingDistillationBase): - def __init__( - self, - temperature: float, - warmup_steps: int, - learning_rate: float, - weight_decay: float, - batch_size: int, - accumulate_grad_batches: int, - max_epochs: int, - max_length: int, - data_dir: str, - label_list: Union[list, ListConfig], - student_model_path: str, - teacher_model_path: str, - num_workers: int, - schedule: Optional[str] = None, - metric: str = "Default", - ): - """A class for sequence tagging (task-specific) intermediate-layer (Transformer, - Embedding) distillation step. - - :param temperature: - :param warmup_steps: - :param learning_rate: - :param weight_decay: - :param batch_size: - :param accumulate_grad_batches: - :param max_epochs: - :param max_length: - :param data_dir: - :param label_list: - :param student_model_path: - :param teacher_model_path: - :param num_workers: - :param schedule: - :param metric: - """ - super().__init__( - temperature=temperature, - warmup_steps=warmup_steps, - weight_decay=weight_decay, - learning_rate=learning_rate, - batch_size=batch_size, - accumulate_grad_batches=accumulate_grad_batches, - max_epochs=max_epochs, - max_length=max_length, - data_dir=data_dir, - label_list=label_list, - student_model_path=student_model_path, - teacher_model_path=teacher_model_path, - num_workers=num_workers, - schedule=schedule, - metric=metric, - ) - - self.loss = MSELoss() - self.save_hyperparameters() - - def _run_step(self, batch: Any) -> tuple[torch.Tensor, torch.Tensor]: - """Function for the training/validation step. Computes attention based - distillation loss and hidden states based distillation loss. - - :param batch: The output of DataLoader. - :return: A tuple of tensors. 
(rep_loss, att_loss) - - rep_loss: hidden states based distillation loss (includes embedding-layer distillation) - att_loss: attention based distillation loss - """ - student_logits, student_atts, student_reps = self.student_model( - input_ids=batch["input_ids"], - token_type_ids=batch["token_type_ids"], - attention_mask=batch["attention_mask"], - is_student=True, - ) - - # self.teacher_model.eval() - with torch.no_grad(): - teacher_logits, teacher_atts, teacher_reps = self.teacher_model( - input_ids=batch["input_ids"], - token_type_ids=batch["token_type_ids"], - attention_mask=batch["attention_mask"], - ) - - teacher_layer_num = len(teacher_atts) - student_layer_num = len(student_atts) - assert teacher_layer_num % student_layer_num == 0 - layers_per_block = int(teacher_layer_num / student_layer_num) - - att_loss = torch.Tensor([0.0]) - rep_loss = torch.Tensor([0.0]) - - new_teacher_atts = [ - teacher_atts[i * layers_per_block + layers_per_block - 1] - for i in range(student_layer_num) - ] - - for student_att, teacher_att in zip(student_atts, new_teacher_atts): - student_att_cond = torch.where( - student_att <= -1e2, torch.zeros_like(student_att), student_att - ) - teacher_att_cond = torch.where( - teacher_att <= -1e2, torch.zeros_like(teacher_att), teacher_att - ) - - att_loss += self.loss(student_att_cond, teacher_att_cond) - - new_teacher_reps = [ - teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1) - ] - - for student_rep, teacher_rep in zip(student_reps, new_teacher_reps): - rep_loss += self.loss(student_rep, teacher_rep) - - return rep_loss, att_loss - - def training_step(self, batch, batch_idx): - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.training_step - `\\ .""" - rep_loss, att_loss = self._run_step(batch) - loss = rep_loss + att_loss - - # Logging - self.log("training_loss", loss, on_step=True) - self.log("att_loss", att_loss, on_step=True) - self.log("rep_loss", rep_loss, on_step=True) - scheduler = self.lr_schedulers() - assert isinstance(scheduler, LRScheduler) - lr_list = scheduler.get_last_lr() - self.log("lr", lr_list[0], prog_bar=True, on_step=True) - if lr_list[0] != lr_list[1]: - self.log("lr1", lr_list[1], prog_bar=True, on_step=True) - - return loss - - def validation_step(self, batch, batch_idx): - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.validation_step - `\\ .""" - rep_loss, att_loss = self._run_step(batch) - loss = rep_loss + att_loss - return { - "loss": loss.detach().cpu(), - "rep_loss": rep_loss.detach().cpu(), - "att_loss": att_loss.detach().cpu(), - } - - def validation_epoch_end(self, val_step_outputs): - """Implementation of :meth:`LightningModule.validation_epoch_end - `\\ .""" - - epoch_loss_mean = np.mean([x["loss"] for x in val_step_outputs]) - self.log( - self.metric, - epoch_loss_mean, - prog_bar=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) - - epoch_rep_loss_mean = np.mean([x["rep_loss"] for x in val_step_outputs]) - self.log( - "val_rep_loss", - epoch_rep_loss_mean, - prog_bar=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) - epoch_att_loss_mean = np.mean([x["att_loss"] for x in val_step_outputs]) - self.log( - "val_att_loss", - epoch_att_loss_mean, - prog_bar=True, - on_step=False, - on_epoch=True, - sync_dist=True, - ) diff --git a/kazu/distillation/tiny_transformers.py b/kazu/distillation/tiny_transformers.py deleted file mode 100644 index f125cc20..00000000 --- a/kazu/distillation/tiny_transformers.py +++ /dev/null @@ -1,95 +0,0 @@ 
-from torch import nn - -from transformers import BertModel, BertPreTrainedModel - - -# ignore required because transformers doesn't distribute type information -# so to mypy, this is subclassing 'Any'. -class TinyBertForSequenceTagging(BertPreTrainedModel): # type: ignore[misc] - """PyTorch BERT model for sequence tagging. - - Based off `TinyBERT from Huawei Noah's Ark Lab `_ - - the `TinyBertForSequenceClassification `_ - class specifically. - - Modified for distillation using Pytorch Lightning by KAZU team. - - Originally Licensed under Apache 2.0 - - | Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team., and KAZU Team - | Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. - | Copyright 2020 Huawei Technologies Co., Ltd - - .. raw:: html - -
- Full License Notice - - | Copyright 2018 The Google AI Language Team Authors, The HuggingFace Inc. team., and KAZU Team - | Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. - | Copyright 2020 Huawei Technologies Co., Ltd - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - .. raw:: html - -
- """ - - def __init__(self, config, num_labels=None, fit_size=768): - super(TinyBertForSequenceTagging, self).__init__(config) - - if num_labels is None: - self.num_labels = config.num_labels - else: - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(in_features=config.hidden_size, out_features=self.num_labels) - self.fit_dense = nn.Linear(config.hidden_size, fit_size) - self.init_weights() - - def forward( - self, input_ids, token_type_ids=None, attention_mask=None, labels=None, is_student=False - ): - """Defines the computation performed when the model is called. - - Note that users should call the - :class:`TinyBertForSequenceTagging` instance itself, rather than - this method directly, because calling the instance runs - registered 'hooks' on the instance. - - This works as this class inherits (through its base class) from - :class:`torch.nn.Module`\\ , which defines __call__ to call the - forward method, as well as registered hooks. - """ - - output = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - output_hidden_states=True, - output_attentions=True, - return_dict=True, - ) - sequence_output = output["hidden_states"] - # sequence_output [list of torch tensor] = (number of layers + 1) * [batch_size, sequence_length, hidden_size] - att_output = output["attentions"] - logits = self.classifier(sequence_output[-1]) - - tmp = [] - if is_student: - for sequence_layer in sequence_output: - tmp.append(self.fit_dense(sequence_layer)) - sequence_output = tmp - return logits, att_output, sequence_output diff --git a/kazu/distillation/train.py b/kazu/distillation/train.py deleted file mode 100644 index 90b9fb75..00000000 --- a/kazu/distillation/train.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Union - -import hydra -import pytorch_lightning -import torch -from hydra.utils import instantiate -from omegaconf import DictConfig -from pytorch_lightning import Trainer - -from kazu.distillation.models import ( - SequenceTaggingDistillationForFinalLayer, - SequenceTaggingDistillationForIntermediateLayer, -) -from kazu.utils.constants import HYDRA_VERSION_BASE - - -@hydra.main(version_base=HYDRA_VERSION_BASE, config_path="../../conf", config_name="config") -def start(cfg: DictConfig) -> None: - - if torch.backends.cudnn.is_available(): # for reproducibility - torch.backends.cudnn.deterministic = cfg.DistillationTraining.cudnn.deterministic - torch.backends.cudnn.benchmark = ( - cfg.DistillationTraining.cudnn.benchmark # set true if the model is not for research; True will make training faster - ) - - pytorch_lightning.seed_everything(cfg.DistillationTraining.seed) - trainer: Trainer = instantiate(cfg.DistillationTraining.trainer) - model: Union[ - SequenceTaggingDistillationForFinalLayer, SequenceTaggingDistillationForIntermediateLayer - ] = instantiate(cfg.DistillationTraining.model) - trainer.fit(model) - - -if __name__ == "__main__": - start() diff --git a/kazu/tests/test_distillation_training.py b/kazu/tests/test_distillation_training.py deleted file mode 100644 index 7853730e..00000000 --- a/kazu/tests/test_distillation_training.py +++ /dev/null @@ -1,18 +0,0 @@ -from kazu.distillation.train import start -from kazu.tests.utils import TEST_ASSETS_PATH, BERT_TEST_MODEL_PATH - -DATA_DIR = TEST_ASSETS_PATH.joinpath("tinybern") - - -def test_stage_2_tinybert_distillation(tmp_path, override_kazu_test_config): - cfg = override_kazu_test_config( - overrides=[ 
- f"DistillationTraining.model.student_model_path={BERT_TEST_MODEL_PATH}", - f"DistillationTraining.model.teacher_model_path={BERT_TEST_MODEL_PATH}", - f"DistillationTraining.model.data_dir={DATA_DIR}", - f"DistillationTraining.save_dir={tmp_path}", - "DistillationTraining.training_params.max_epochs=2", - ], - ) - - start(cfg) diff --git a/pyproject.toml b/pyproject.toml index 5e0c27ab..6c78604d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,8 +63,7 @@ webserver = [ "PyJWT>=2.0.0", ] llm = [ - "google-cloud-aiplatform", # Required for vertex API - "openai", # Required for OpenAI API + "openai", # Required for OpenAI API ] typed = [ # version 2.31.0.3 introduced @@ -77,16 +76,6 @@ typed = [ "pandas-stubs>=2.0.0", "types-tqdm", ] -model-training = [ - "pytorch-metric-learning>=0.9.99", - "seqeval>=1.0.0", - "pytorch-lightning>=1.7.4,<2.0.0", - # version 0.10.0 causes the model - # inference to fail, and lightning - # itself doesn't restrict strictly - # enough - "lightning-utilities==0.9.0", -] all-steps = [ "py4j>=0.10.9", "rdkit>=2023.3.1", @@ -94,7 +83,7 @@ all-steps = [ "gliner==0.1.7", # GLINER is under active development, so pinning version until API stabilises ] dev = [ - "kazu[webserver,typed,model-training,all-steps,model-training,llm]", + "kazu[webserver,typed,all-steps,llm]", "black~=22.0", "blacken-docs", "flake8", @@ -107,7 +96,7 @@ dev = [ "pytest-cov", "pytest-timeout", "hypothesis", - "sphinx>=7.2,<8.0", # 8.0.0 breaks our docs at the moment + "sphinx>=7.2,<8.0", # 8.0.0 breaks our docs at the moment "myst_parser", "furo>=2023.08.17", # to allow profiling @@ -211,11 +200,8 @@ disallow_untyped_defs = false # the numbers can be re-calculated with: # mypy kazu docs conftest.py | cut -f 1,2 -d ':' | sed '$d' | sort | uniq | cut -f 1 -d ':' | uniq -c | sort --reverse | awk '{file_without_dot_py = substr($2, 0, length($2)-3); gsub("/", ".", file_without_dot_py); print " \""file_without_dot_py"\", # "$1}' module = [ - "kazu.distillation.models", # 10 - "kazu.linking.sapbert.train", # 4 - "kazu.utils.spacy_pipeline", # 2 - "kazu.distillation.tiny_transformers", # 2 - "kazu.distillation.metrics", # 1 + "kazu.linking.sapbert.train", # 4 + "kazu.utils.spacy_pipeline", # 2 ] # we had a bunch of these in the codebase before we moved to a 'strict' mypy config, and it was too many # to fix at that time for the payoff. Having overrides for the modules that would error rather than @@ -230,7 +216,6 @@ disallow_untyped_defs = false module = [ "kazu.data", # 13 "kazu.annotation.label_studio", # 9 - "kazu.distillation.models", # 8 "kazu.linking.sapbert.train", # 6 "kazu.steps.linking.post_processing.xref_manager", # 3 "kazu.steps.linking.post_processing.mapping_strategies.strategies", # 3 From 5eacab67ba17e1cc3766b49c3da26723f1f5a0ca Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Mon, 7 Oct 2024 13:14:09 +0100 Subject: [PATCH 12/21] fix typo --- .github/workflows/precommit_and_docs_build_check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/precommit_and_docs_build_check.yml b/.github/workflows/precommit_and_docs_build_check.yml index 9778e828..a8645814 100644 --- a/.github/workflows/precommit_and_docs_build_check.yml +++ b/.github/workflows/precommit_and_docs_build_check.yml @@ -29,7 +29,7 @@ jobs: . 
/tmp/kazu-env/bin/activate python -m pip install --upgrade pip pip install --upgrade --upgrade-strategy eager --index-url https://download.pytorch.org/whl/cpu "torch>=2.0.0" - pip install -e ."[dev]" --cache-dir $PIP_CACHE_DIR --upgrade --upgrade-strategy eager" + pip install -e ."[dev]" --cache-dir $PIP_CACHE_DIR --upgrade --upgrade-strategy eager - name: Check precommit run: | From 3a2b1bad019ddc0266b9b920b6961d9247a4f80e Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Mon, 7 Oct 2024 14:09:47 +0100 Subject: [PATCH 13/21] fix: use pre-commit < 4 * docformatter breaks pre-commit on version 4.0. Need to pin this until this issue is resolved: https://github.com/PyCQA/docformatter/issues/289 * in future when this is resolved also need to run `pre-commit migrate-config` --- pyproject.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6c78604d..4cf625f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ typed = [ # version 2.31.0.3 introduced # overly strict typing of requests.HttpError # which was fixed in 2.31.0.6 + "types-requests", "types-cachetools", "types-regex", "types-psutil", @@ -90,13 +91,14 @@ dev = [ "mypy", "vulture", "bump2version", - "pre-commit", + # docformatter breaks pre-commit on version 4.0. Need to pin this until this issue is resolved: https://github.com/PyCQA/docformatter/issues/289 + "pre-commit<4.0.0", "pytest", "pytest-mock", "pytest-cov", "pytest-timeout", "hypothesis", - "sphinx>=7.2,<8.0", # 8.0.0 breaks our docs at the moment + "sphinx>=7.2,<8.0", # 8.0.0 breaks our docs at the moment "myst_parser", "furo>=2023.08.17", # to allow profiling From 186844a8a53e6712acdc910bf292734b039f8a6b Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Mon, 7 Oct 2024 14:21:00 +0100 Subject: [PATCH 14/21] fix: remove model training code that used pytorch lighting --- kazu/linking/__init__.py | 0 kazu/linking/sapbert/__init__.py | 0 kazu/linking/sapbert/train.py | 370 ------------------------------- 3 files changed, 370 deletions(-) delete mode 100644 kazu/linking/__init__.py delete mode 100644 kazu/linking/sapbert/__init__.py delete mode 100644 kazu/linking/sapbert/train.py diff --git a/kazu/linking/__init__.py b/kazu/linking/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kazu/linking/sapbert/__init__.py b/kazu/linking/sapbert/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kazu/linking/sapbert/train.py b/kazu/linking/sapbert/train.py deleted file mode 100644 index 720f4e2f..00000000 --- a/kazu/linking/sapbert/train.py +++ /dev/null @@ -1,370 +0,0 @@ -import logging -from dataclasses import dataclass -from typing import cast, Any, Optional, Union, NamedTuple - -import hydra -import numpy as np -import pandas as pd -import torch -from hydra.utils import instantiate -from omegaconf import DictConfig -from pydantic import BaseModel -from tokenizers import Encoding -from torch import optim -from torch.utils.data import DataLoader, Dataset -from transformers import ( - BatchEncoding, - PreTrainedTokenizerBase, -) -from transformers.file_utils import PaddingStrategy - -try: - from pytorch_metric_learning import miners, losses - from pytorch_lightning import Trainer, LightningModule - from pytorch_lightning.utilities.types import ( - STEP_OUTPUT, - TRAIN_DATALOADERS, - EVAL_DATALOADERS, - EPOCH_OUTPUT, - ) -except ImportError as e: - raise ImportError( - "Running the SapBERT model training code requires several additional dependencies to be 
installed.\n" - "We recommend running 'pip install kazu[model-training]' to get all model training" - " dependencies." - ) from e - -from kazu.utils.constants import HYDRA_VERSION_BASE -from kazu.utils.sapbert import SapBertHelper - -logger = logging.getLogger(__name__) - - -@dataclass -class SapbertDataCollatorWithPadding: - """Data collator to be used with HFSapbertPairwiseDataset.""" - - tokenizer: PreTrainedTokenizerBase - padding: Union[bool, str, PaddingStrategy] = True - max_length: Optional[int] = None - pad_to_multiple_of: Optional[int] = None - - def __call__( - self, features: list[dict[str, BatchEncoding]] - ) -> tuple[BatchEncoding, BatchEncoding]: - query_toks1 = [x["query_toks1"] for x in features] - query_toks1_enc = self.tokenizer.pad( - query_toks1, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors="pt", - ) - query_toks2 = [x["query_toks2"] for x in features] - query_toks2_enc = self.tokenizer.pad( - query_toks2, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors="pt", - ) - - return query_toks1_enc, query_toks2_enc - - -class HFSapbertPairwiseDataset(Dataset): - """Dataset used for training SapBert.""" - - def __getitem__(self, index: int) -> dict[str, Any]: - query_toks1 = { - "input_ids": self.encodings_1.data["input_ids"][index], - "token_type_ids": self.encodings_1.data["token_type_ids"][index], - "labels": self.encodings_1.data["labels"][index], - "attention_mask": self.encodings_1.data["attention_mask"][index], - } - query_toks2 = { - "input_ids": self.encodings_2.data["input_ids"][index], - "token_type_ids": self.encodings_2.data["token_type_ids"][index], - "labels": self.encodings_2.data["labels"][index], - "attention_mask": self.encodings_2.data["attention_mask"][index], - } - - return {"query_toks1": query_toks1, "query_toks2": query_toks2} - - def __init__(self, encodings_1: BatchEncoding, encodings_2: BatchEncoding, labels: np.ndarray): - """ - :param encodings_1: encodings for example 1 - :param encodings_2: encodings for example 2 - :param labels: labels i.e. knowledgebase identifier for both encodings, as an int - """ - encodings_1["labels"] = labels - encodings_2["labels"] = labels - self.encodings_1 = encodings_1 - self.encodings_2 = encodings_2 - - def __len__(self): - encodings = cast(list[Encoding], self.encodings_1.encodings) - return len(encodings) - - -class SapbertTrainingParams(BaseModel): - topk: int # score this many nearest neighbours against target - lr: float # learning rate - weight_decay: float - miner_margin: float # passed to TripletMarginMiner - type_of_triplets: str # passed to TripletMarginMiner - train_file: str # a parquet file with three columns - syn1, syn2 and id - train_batch_size: int - num_workers: int # passed to dataloaders - - -class SapbertEvaluationDataset(NamedTuple): - """To evaluate a given embedding model, we need a query datasource (i.e. things that - need to be linked)] and an ontology datasource (i.e. 
things that we need to generate - an embedding space for, that can be queried against) each should have three columns: - - default_label (text), iri (ontology id) and source (ontology name) - """ - - query_source: pd.DataFrame - ontology_source: pd.DataFrame - - -class Candidate(NamedTuple): - """A knowledgebase entry.""" - - default_label: str - iri: str - correct: bool - - -class GoldStandardExample(NamedTuple): - gold_default_label: str - gold_iri: str - candidates: list[ - Candidate - ] # candidates (aka nearest neighbours) associate with this gold instance - - -class SapbertEvaluationDataManager: - """Manages the loading/parsing of multiple evaluation datasets. Each dataset should - have two sources, a query source and an ontology source. these are then converted - into data loaders, while maintaining a reference to the embedding metadata that - should be used for evaluation. - - self.dataset is dict[dataset_name, SapbertEvaluationDataset] after construction - """ - - def __init__(self, sources: dict[str, list[str]], debug: bool = False): - self.datasets: dict[str, SapbertEvaluationDataset] = {} - for source_name, ( - query_source_path, - ontology_source_path, - ) in sources.items(): - query_source = pd.read_parquet(query_source_path) - ontology_source = pd.read_parquet(ontology_source_path) - if debug: - query_source = query_source.head(10) - ontology_source = ontology_source.sample(frac=1.0).head(100) - self.datasets[source_name] = SapbertEvaluationDataset( - query_source=query_source, ontology_source=ontology_source - ) - - -class PLSapbertModel(LightningModule): - def __init__( - self, - model_name_or_path: str, - sapbert_training_params: Optional[SapbertTrainingParams] = None, - sapbert_evaluation_manager: Optional[SapbertEvaluationDataManager] = None, - *args: Any, - **kwargs: Any, - ): - """ - :param model_name_or_path: passed to AutoModel.from_pretrained - :param sapbert_training_params: optional SapbertTrainingParams, only needed if training a model - :param sapbert_evaluation_manager: optional SapbertEvaluationDataManager, only needed if training a model - :param args: passed to LightningModule - :param kwargs: passed to LightningModule - """ - - super().__init__(*args, **kwargs) - self.sapbert_helper = SapBertHelper(model_name_or_path) - self.model = self.sapbert_helper.model - self.sapbert_evaluation_manager = sapbert_evaluation_manager - self.sapbert_training_params = sapbert_training_params - if sapbert_training_params is not None: - self.loss = losses.MultiSimilarityLoss(alpha=1, beta=60, base=0.5) - self.miner = miners.TripletMarginMiner( - margin=sapbert_training_params.miner_margin, - type_of_triplets=sapbert_training_params.type_of_triplets, - ) - - def configure_optimizers(self): - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.configure_optimizers - `\\ .""" - assert self.sapbert_training_params is not None - optimizer = optim.AdamW( - [ - {"params": self.model.parameters()}, - ], - lr=self.sapbert_training_params.lr, - weight_decay=self.sapbert_training_params.weight_decay, - ) - return optimizer - - def forward(self, batch: BatchEncoding) -> dict[int, torch.Tensor]: - """For inference. 
- - :param batch: standard bert input, with an additional 'indices' for representing - the location of the embedding - :return: - """ - return self.sapbert_helper.get_prediction_from_batch(self.model, batch) - - def training_step(self, batch: Any, batch_idx: int, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.training_step - `\\ .""" - query_toks1, query_toks2 = batch - # labels should be identical, so we only need one - labels = query_toks1.pop("labels") - query_toks2.pop("labels") - last_hidden_state1 = self.model(**query_toks1, return_dict=True).last_hidden_state - last_hidden_state2 = self.model(**query_toks2, return_dict=True).last_hidden_state - query_embed1 = last_hidden_state1[:, 0] # query : [batch_size, hidden] - query_embed2 = last_hidden_state2[:, 0] # query : [batch_size, hidden] - query_embed = torch.cat([query_embed1, query_embed2], dim=0) - labels = torch.cat([labels, labels], dim=0) - hard_pairs = self.miner(query_embed, labels) - return self.loss(query_embed, labels, hard_pairs) # type: ignore [no-any-return] # no clear type info for the loss - - def train_dataloader(self) -> TRAIN_DATALOADERS: - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.train_dataloader - `\\ .""" - assert self.sapbert_training_params is not None - training_df = pd.read_parquet(self.sapbert_training_params.train_file) - labels = training_df["id"].astype("category").cat.codes.to_numpy() - encodings_1 = self.tokeniser(training_df["syn1"].tolist()) - encodings_2 = self.tokeniser(training_df["syn2"].tolist()) - encodings_1["labels"] = labels - encodings_2["labels"] = labels - train_set = HFSapbertPairwiseDataset( - labels=labels, encodings_1=encodings_1, encodings_2=encodings_2 - ) - train_loader = DataLoader( - train_set, - batch_size=self.sapbert_training_params.train_batch_size, - shuffle=True, - num_workers=self.sapbert_training_params.num_workers, - collate_fn=SapbertDataCollatorWithPadding( - self.tokeniser, padding=PaddingStrategy.LONGEST - ), - ) - - return train_loader - - def val_dataloader(self) -> EVAL_DATALOADERS: - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.val_dataloader - `\\ .""" - dataloaders = [] - assert self.sapbert_evaluation_manager is not None - assert self.sapbert_training_params is not None - for query_source, ontology_source in self.sapbert_evaluation_manager.datasets.values(): - query_dataloader = self.sapbert_helper.get_embedding_dataloader_from_strings( - texts=query_source["default_label"].tolist(), - batch_size=self.sapbert_training_params.train_batch_size, - num_workers=self.sapbert_training_params.num_workers, - ) - ontology_dataloader = self.sapbert_helper.get_embedding_dataloader_from_strings( - texts=ontology_source["default_label"].tolist(), - batch_size=self.sapbert_training_params.train_batch_size, - num_workers=self.sapbert_training_params.num_workers, - ) - dataloaders.append(query_dataloader) - dataloaders.append(ontology_dataloader) - - return dataloaders - - def validation_step( - self, batch: Any, batch_idx: int, dataset_idx: int - ) -> Optional[STEP_OUTPUT]: - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.validation_step - `\\ .""" - return self(batch) # type: ignore [no-any-return] # no type info from Pytorch Lightning - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: - """Implementation of - :external+pytorch_lightning:ref:`LightningModule.predict_step - `\\ 
.""" - return self(batch) - - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, list[EPOCH_OUTPUT]]) -> None: - """Lightning override generate new embeddings for each - :attr:`SapbertEvaluationDataset.ontology_source` and query them with - :attr:`SapbertEvaluationDataset.query_source` - - :param outputs: - :return: - """ - assert self.sapbert_evaluation_manager is not None - raise NotImplementedError - - def log_results(self, dataset_name, metrics): - for key, val in metrics.items(): - if key.startswith("acc"): - self.log(key, value=val, rank_zero_only=True) - logger.info(f"{dataset_name}: {key}, {val}") - - def get_candidate_dict(self, np_candidates: pd.DataFrame, golden_iri: str) -> list[Candidate]: - """Convert rows in a dataframe representing candidate KB entries into a - corresponding :class:`Candidate` per row. - - :param np_candidates: - :param golden_iri: - :return: - """ - candidates_filtered = [] - for i, np_candidate_row in np_candidates.iterrows(): - candidates_filtered.append( - Candidate( - default_label=np_candidate_row["default_label"], - iri=np_candidate_row["iri"], - correct=np_candidate_row["iri"] == golden_iri, - ) - ) - return candidates_filtered - - def evaluate_topk_acc(self, queries: list[GoldStandardExample]) -> dict[str, float]: - """Get a dictionary of accuracy results at different levels of k (nearest - neighbours) - - :param queries: - :return: - """ - k = len(queries[0].candidates) - result = {} - for i in range(0, k): - hit = 0 - for query in queries: - candidates = query.candidates[: i + 1] # to get acc@(i+1) - if any(candidate.correct for candidate in candidates): - hit += 1 - result["acc{}".format(i + 1)] = hit / len(queries) - - return result - - -@hydra.main(version_base=HYDRA_VERSION_BASE, config_path="../../../conf", config_name="config") -def start(cfg: DictConfig) -> None: - trainer: Trainer = instantiate(cfg.SapBertTraining.trainer) - model: PLSapbertModel = instantiate(cfg.SapBertTraining.model) - trainer.fit(model) - - -if __name__ == "__main__": - start() From f0d4c4b1316c5a9b51cf081b073f6676b80e577b Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Mon, 7 Oct 2024 14:27:07 +0100 Subject: [PATCH 15/21] fix: add requirement back in --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4cf625f0..bb870a69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,8 @@ webserver = [ "PyJWT>=2.0.0", ] llm = [ - "openai", # Required for OpenAI API + "google-cloud-aiplatform", # Required for vertex API + "openai", # Required for OpenAI API ] typed = [ # version 2.31.0.3 introduced @@ -202,8 +203,7 @@ disallow_untyped_defs = false # the numbers can be re-calculated with: # mypy kazu docs conftest.py | cut -f 1,2 -d ':' | sed '$d' | sort | uniq | cut -f 1 -d ':' | uniq -c | sort --reverse | awk '{file_without_dot_py = substr($2, 0, length($2)-3); gsub("/", ".", file_without_dot_py); print " \""file_without_dot_py"\", # "$1}' module = [ - "kazu.linking.sapbert.train", # 4 - "kazu.utils.spacy_pipeline", # 2 + "kazu.utils.spacy_pipeline", # 2 ] # we had a bunch of these in the codebase before we moved to a 'strict' mypy config, and it was too many # to fix at that time for the payoff. 
Having overrides for the modules that would error rather than From c5e48c3fecb28eb787b27d5877b685e05f1072fc Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Mon, 7 Oct 2024 14:58:18 +0100 Subject: [PATCH 16/21] fix: remove further mentions of training --- docs/conf.py | 5 ---- kazu/conf/SapBertTraining/default.yaml | 29 ------------------------ kazu/utils/model_pack_build_logging.conf | 7 +----- pyproject.toml | 2 -- 4 files changed, 1 insertion(+), 42 deletions(-) delete mode 100644 kazu/conf/SapBertTraining/default.yaml diff --git a/docs/conf.py b/docs/conf.py index 48a00f62..14ede3c4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -125,7 +125,6 @@ intersphinx_mapping = { "transformers": ("https://huggingface.co/docs/transformers/main/en/", None), - "pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None), "torch": ("https://pytorch.org/docs/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), @@ -207,16 +206,12 @@ def linkcode_resolve(domain: str, info: dict[str, Any]) -> Union[str, None]: "transformers.tokenization_utils_fast.PreTrainedTokenizerFast": "transformers.PreTrainedTokenizerFast", "transformers.utils.generic.PaddingStrategy": "transformers.utils.PaddingStrategy", "lightning_fabric.plugins.io.checkpoint_io.CheckpointIO": "lightning.pytorch.plugins.io.CheckpointIO", - "pytorch_lightning.core.module.LightningModule": "lightning.pytorch.core.LightningModule", - "pytorch_lightning.trainer.trainer.Trainer": "lightning.pytorch.trainer.trainer.Trainer", "urllib3.util.retry.Retry": "urllib3.util.Retry", "scipy.sparse._csr.csr_matrix": "scipy.sparse.csr_matrix", } nitpick_ignore = [ - # this doesn't exist anymore in lightning 2.0, it becomes on_validation_epoch_end, and there's some migration work for changing to it - ("py:meth", "pytorch_lightning.core.LightningModule.validation_epoch_end"), # this doesn't appear to have an entry in the transformers docs for some reason. ("py:class", "transformers.models.bert.modeling_bert.BertPreTrainedModel"), # the kazu.utils.grouping.Key TypeVar tries to generate this automatically. diff --git a/kazu/conf/SapBertTraining/default.yaml b/kazu/conf/SapBertTraining/default.yaml deleted file mode 100644 index fa891355..00000000 --- a/kazu/conf/SapBertTraining/default.yaml +++ /dev/null @@ -1,29 +0,0 @@ -model: - _target_: kazu.linking.sapbert.train.PLSapbertModel - model_name_or_path: ${oc.env:KAZU_MODEL_PACK}/sapbert/SapBERT-from-PubMedBERT-fulltext - sapbert_training_params: - _target_: kazu.linking.sapbert.train.SapbertTrainingParams - topk: 5 - lr: 2e-5 - weight_decay: 0.01 - miner_margin: 0.2 - type_of_triplets: "all" - train_file: ??? - train_batch_size: 256 - num_workers: 0 - sapbert_evaluation_manager: - _target_: kazu.linking.sapbert.train.SapbertEvaluationDataManager - debug: true - sources: - bc5cdr-disease: - - ??? - - ??? - mondo: - - ??? - - ??? 
-trainer: - _target_: pytorch_lightning.Trainer - enable_progress_bar: False - num_sanity_val_steps: 2 - gpus: 0 - accelerator: ddp diff --git a/kazu/utils/model_pack_build_logging.conf b/kazu/utils/model_pack_build_logging.conf index 4f462e31..73e3e8e3 100644 --- a/kazu/utils/model_pack_build_logging.conf +++ b/kazu/utils/model_pack_build_logging.conf @@ -1,5 +1,5 @@ [loggers] -keys=root,pytorch_lightning +keys=root [handlers] keys=consoleHandler,nullHandler @@ -11,11 +11,6 @@ keys=simpleFormatter level=INFO handlers=consoleHandler -[logger_pytorch_lightning] -handlers=nullHandler -propagate=0 -qualname=pytorch_lightning - [handler_consoleHandler] class=StreamHandler level=INFO diff --git a/pyproject.toml b/pyproject.toml index bb870a69..69c489de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -178,7 +178,6 @@ module = [ "tokenizers.*", "stanza.*", "sklearn.*", - "pytorch_metric_learning.*", "srsly.*", "py4j.*", "diskcache.*", @@ -218,7 +217,6 @@ disallow_untyped_defs = false module = [ "kazu.data", # 13 "kazu.annotation.label_studio", # 9 - "kazu.linking.sapbert.train", # 6 "kazu.steps.linking.post_processing.xref_manager", # 3 "kazu.steps.linking.post_processing.mapping_strategies.strategies", # 3 "kazu.web.server", # 2 From 3917f6aa599fa9258f91791425ab439139b2d20b Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Mon, 7 Oct 2024 16:20:52 +0100 Subject: [PATCH 17/21] pin spacy --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69c489de..8cdf59a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ classifiers = [ "Topic :: Scientific/Engineering :: Medical Science Apps.", ] dependencies = [ - "spacy>=3.2.2", + # pinning spacy because larger versions cause dependcy issues with thinc and various packages on 7/10/2024 + "spacy==3.2.2", # note to kazu devs: if the torch version changes in future, # we need to change for CI jobs where we specify this install # elsewhere because of the CPU wheel needing a different pypi index. @@ -34,7 +35,7 @@ dependencies = [ "scikit-learn>=0.24.0", # scipy 1.12.0 introduced many changes to the sparse matrices api. https://docs.scipy.org/doc/scipy/reference/sparse.html#module-scipy.sparse # This is causing our acceptance tests to fail. Pinning to <1.12.0 until it's confirmed other libraries (e.g. sk-learn) don't have issues. - "scipy", + "scipy<1.12.0", "regex>=2020.1.7", "psutil>=5.3.0", "cachetools>=5.2.0", From 7318dfee74aacc97a4dd4358cd149f9046395839 Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Tue, 8 Oct 2024 10:11:15 +0100 Subject: [PATCH 18/21] fix: broken spacy version on python 3.11 --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8cdf59a0..7e3387f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ classifiers = [ ] dependencies = [ # pinning spacy because larger versions cause dependcy issues with thinc and various packages on 7/10/2024 - "spacy==3.2.2", + # Also needs to be at least 3.4.2 on python 3.11 (see https://github.com/explosion/spaCy/issues/12697) + "spacy==3.7.0", # note to kazu devs: if the torch version changes in future, # we need to change for CI jobs where we specify this install # elsewhere because of the CPU wheel needing a different pypi index. @@ -59,7 +60,7 @@ webserver = [ # UPDATE: 0.109.0 also no longer works. 
It seems like there's some incompatibility with ray serve here, so pinning to # last known good version "fastapi==0.108.0", - "pydantic<2.0", + "pydantic<2.0", # for the pinned spacy version pydantic should be > 1.8 as well. "ray[serve]>=2.0.0", "PyJWT>=2.0.0", ] From 809b870f915263d5e79a681061fe7f9f14c25c09 Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Tue, 8 Oct 2024 11:33:53 +0100 Subject: [PATCH 19/21] update sphinx --- docs/conf.py | 1 - pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 14ede3c4..a771867c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -132,7 +132,6 @@ "sklearn": ("https://scikit-learn.org/stable/", None), "diskcache": ("https://grantjenks.com/docs/diskcache/", None), "rdflib": ("https://rdflib.readthedocs.io/en/stable/", None), - "pymongo": ("https://pymongo.readthedocs.io/en/stable/", None), # pymongo includes bson "bson": ("https://pymongo.readthedocs.io/en/stable/", None), "requests": ("https://requests.readthedocs.io/en/latest/", None), diff --git a/pyproject.toml b/pyproject.toml index 7e3387f4..9e818c0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ dev = [ "pytest-cov", "pytest-timeout", "hypothesis", - "sphinx>=7.2,<8.0", # 8.0.0 breaks our docs at the moment + "sphinx>=7.2", "myst_parser", "furo>=2023.08.17", # to allow profiling From e384014fd595aca7a915c299786c7df49380779b Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Tue, 8 Oct 2024 11:43:25 +0100 Subject: [PATCH 20/21] fix: pin sphinx to avoid ssl errors --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9e818c0c..eb543d80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ dev = [ "pytest-cov", "pytest-timeout", "hypothesis", - "sphinx>=7.2", + "sphinx>=7.2,<8.0", # 8.0.0 leads to strange SSL errors, both locally and on the runner", "myst_parser", "furo>=2023.08.17", # to allow profiling From 104437b93e9637e63ace25cc10d9dda73f8637af Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Tue, 8 Oct 2024 11:49:04 +0100 Subject: [PATCH 21/21] fix: remove reference to removed code --- kazu/conf/config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/kazu/conf/config.yaml b/kazu/conf/config.yaml index 50f05a06..e7fd60af 100644 --- a/kazu/conf/config.yaml +++ b/kazu/conf/config.yaml @@ -26,7 +26,6 @@ defaults: - SpacyNerStep: default - TransformersModelForTokenClassificationNerStep: default - DictionaryEntityLinkingStep: default - - SapBertTraining: default - Pipeline: default - OpsinStep: default - SethStep: default
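Taken together, the dependency changes in this series (the spacy==3.7.0 and scipy<1.12.0 pins, pre-commit<4.0.0, sphinx>=7.2,<8.0, and the removal of the model-training extra from the dev install) can be smoke-tested locally with the commands below. This is a minimal sketch, assuming a checkout of the repository and Python 3.11: the virtualenv path and the CPU-only torch index URL mirror the CI workflow edited in these patches, while the trailing pip check and import check are illustrative additions rather than steps the workflow itself runs.

    # Minimal local smoke test of the final pinned dependency set (illustrative sketch).
    # Run from the root of a repository checkout with Python 3.11 available.
    rm -r /tmp/kazu-env || true
    python3.11 -m venv /tmp/kazu-env
    . /tmp/kazu-env/bin/activate
    python -m pip install --upgrade pip
    # Install the CPU-only torch wheel first, as the workflow does, so the editable
    # install below does not pull the default CUDA build.
    pip install --upgrade --upgrade-strategy eager --index-url https://download.pytorch.org/whl/cpu "torch>=2.0.0"
    pip install -e ."[dev]"
    # Confirm the resolved environment has no conflicting requirements,
    # then that the pinned libraries import and report the expected versions.
    pip check
    python -c "import spacy, scipy, sphinx; print(spacy.__version__, scipy.__version__, sphinx.__version__)"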