diff --git a/kazu/steps/ner/hf_token_classification.py b/kazu/steps/ner/hf_token_classification.py
index d785827d..fc793ac5 100644
--- a/kazu/steps/ner/hf_token_classification.py
+++ b/kazu/steps/ner/hf_token_classification.py
@@ -1,27 +1,29 @@
 import logging
 from collections.abc import Iterator
-from typing import Optional, cast, Any, Iterable
+from typing import Any, Iterable, Optional, cast
 
 import torch
 from torch import Tensor, softmax
 from torch.utils.data import DataLoader, IterableDataset
 from transformers import (
-    AutoModelForTokenClassification,
     AutoConfig,
+    AutoModelForTokenClassification,
     AutoTokenizer,
+    BatchEncoding,
     DataCollatorWithPadding,
     PreTrainedTokenizerBase,
-    BatchEncoding,
 )
 from transformers.file_utils import PaddingStrategy
 
-from kazu.data import Section, Document
+from kazu.data import Document, Section
 from kazu.steps import Step, document_batch_step
 from kazu.steps.ner.entity_post_processing import NonContiguousEntitySplitter
-from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor, TokenizedWord
+from kazu.steps.ner.tokenized_word_processor import (
+    TokenizedWord,
+    TokenizedWordProcessor,
+)
 from kazu.utils.utils import documents_to_document_section_batch_encodings_map
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -288,26 +290,28 @@ def get_list_of_batch_encoding_frames_for_section(
     def get_multilabel_activations(self, loader: DataLoader) -> Tensor:
         """Get a tensor consisting of confidences for labels in a multi label
-        classification context.
+        classification context. Output tensor is of shape (n_samples,
+        max_sequence_length, n_labels).
 
         :param loader:
         :return:
         """
         with torch.no_grad():
             results = torch.cat(
-                tuple(self.model(**batch.to(self.device)).logits for batch in loader)
+                tuple(self.model(**batch.to(self.device)).logits.to("cpu") for batch in loader)
             ).to(self.device)
         return results.heaviside(torch.tensor([0.0]).to(self.device)).int().to("cpu")
 
     def get_single_label_activations(self, loader: DataLoader) -> Tensor:
         """Get a tensor consisting of one hot binary classifications in a single label
-        classification context.
+        classification context. Output tensor is of shape (n_samples,
+        max_sequence_length, n_labels).
 
         :param loader:
         :return:
         """
         with torch.no_grad():
             results = torch.cat(
-                tuple(self.model(**batch.to(self.device)).logits for batch in loader)
+                tuple(self.model(**batch.to(self.device)).logits.to("cpu") for batch in loader)
             )
         return softmax(results, dim=-1).to("cpu")
 
diff --git a/kazu/training/config.py b/kazu/training/config.py
index dd9d6e11..8d0209fe 100644
--- a/kazu/training/config.py
+++ b/kazu/training/config.py
@@ -45,6 +45,8 @@ class TrainingConfig:
     architecture: str = "bert"
     #: fraction of epoch to complete before evaluations begin
     epoch_completion_fraction_before_evals: float = 0.75
+    #: The random seed to use
+    seed: int = 42
 
 
 @dataclass
diff --git a/kazu/training/evaluate_script.py b/kazu/training/evaluate_script.py
index 1c441d8f..349b58c0 100644
--- a/kazu/training/evaluate_script.py
+++ b/kazu/training/evaluate_script.py
@@ -8,6 +8,7 @@
 from pathlib import Path
 
 import hydra
+import tqdm
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 
@@ -19,6 +20,7 @@
 from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor
 from kazu.training.config import PredictionConfig
 from kazu.training.modelling_utils import (
+    chunks,
     create_wrapper,
     doc_yielder,
     get_label_list_from_model,
@@ -69,10 +71,16 @@ def main(cfg: DictConfig) -> None:
     documents = move_entities_to_metadata(documents)
     print("Predicting with the KAZU pipeline")
     start = time.time()
-    pipeline(documents)
+    docs_in_batch = 10
+    for documents_batch in tqdm.tqdm(
+        chunks(documents, docs_in_batch), total=-(-len(documents) // docs_in_batch)
+    ):
+        pipeline(documents_batch)
     print(f"Predicted {len(documents)} documents in {time.time() - start:.2f} seconds.")
+    Path(cfg.predictions_dir).mkdir(parents=True, exist_ok=True)
     save_out_predictions(Path(cfg.predictions_dir), documents)
+
 
     print("Calculating metrics")
     metrics, _ = calculate_metrics(0, documents, label_list)
     with open(Path(prediction_config.path) / "test_metrics.json", "w") as file:
diff --git a/kazu/training/modelling_utils.py b/kazu/training/modelling_utils.py
index 95f33350..de44310f 100644
--- a/kazu/training/modelling_utils.py
+++ b/kazu/training/modelling_utils.py
@@ -1,6 +1,9 @@
+import copy
 import json
+import logging
+from collections.abc import Iterable
 from pathlib import Path
-from typing import Iterable, Optional
+from typing import Any, Optional, Union
 
 from hydra.utils import instantiate
 from omegaconf import DictConfig
@@ -9,12 +12,17 @@
     LabelStudioAnnotationView,
     LabelStudioManager,
 )
-from kazu.data import ENTITY_OUTSIDE_SYMBOL, Document, Entity, Section
-from kazu.training.train_multilabel_ner import (
-    LSManagerViewWrapper,
+from kazu.data import (
+    ENTITY_OUTSIDE_SYMBOL,
+    PROCESSING_EXCEPTION,
+    Document,
+    Entity,
+    Section,
 )
 from kazu.utils.utils import PathLike
 
+logger = logging.getLogger(__name__)
+
 
 def doc_yielder(path: PathLike) -> Iterable[Document]:
     for file in Path(path).iterdir():
@@ -46,6 +54,12 @@ def test_doc_yielder() -> Iterable[Document]:
     yield doc
 
 
+def chunks(lst: list[Any], n: int) -> Iterable[list[Any]]:
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield lst[i : i + n]
+
+
 def get_label_list(path: PathLike) -> list[str]:
     label_list = set()
     for doc in doc_yielder(path):
@@ -64,6 +78,49 @@ def get_label_list_from_model(model_config_path: PathLike) -> list[str]:
     return label_list
 
 
+class LSManagerViewWrapper:
+    def __init__(self, view: LabelStudioAnnotationView, ls_manager: LabelStudioManager):
+        self.ls_manager = ls_manager
+        self.view = view
+
+    def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list[Document]]:
+        result = []
+        for doc in docs:
+            doc_cp = copy.deepcopy(doc)
+            if PROCESSING_EXCEPTION in doc_cp.metadata:
+                logger.error(doc.metadata[PROCESSING_EXCEPTION])
+                break
+            for section in doc_cp.sections:
+                gold_ents = []
+                for ent in section.metadata.get("gold_entities", []):
+                    if isinstance(ent, dict):
+                        ent = Entity.from_dict(ent)
+                    gold_ents.append(ent)
+                section.entities = gold_ents
+            result.append([doc_cp, doc])
+        return result
+
+    def update(
+        self, docs: list[Document], global_step: Union[int, str], has_gs: bool = True
+    ) -> None:
+        ls_manager = LabelStudioManager(
+            headers=self.ls_manager.headers,
+            project_name=f"{self.ls_manager.project_name}_test_{global_step}",
+        )
+        ls_manager.delete_project_if_exists()
+        ls_manager.create_linking_project()
+        if not docs:
+            logger.info("no results to represent yet")
+            return
+        if has_gs:
+            side_by_side = self.get_gold_ents_for_side_by_side_view(docs)
+            ls_manager.update_view(self.view, side_by_side)
+            ls_manager.update_tasks(side_by_side)
+        else:
+            ls_manager.update_view(self.view, docs)
+            ls_manager.update_tasks(docs)
+
+
 def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]:
     if cfg.get("label_studio_manager") and cfg.get("css_colors"):
         ls_manager: LabelStudioManager = instantiate(cfg.label_studio_manager)
diff --git a/kazu/training/train_multilabel_ner.py b/kazu/training/train_multilabel_ner.py
index 584195ac..46c2fc61 100644
--- a/kazu/training/train_multilabel_ner.py
+++ b/kazu/training/train_multilabel_ner.py
@@ -27,12 +27,10 @@
 )
 
 from kazu.annotation.acceptance_test import aggregate_ner_results, score_sections
-from kazu.annotation.label_studio import LabelStudioAnnotationView, LabelStudioManager
 from kazu.data import (
     ENTITY_OUTSIDE_SYMBOL,
     PROCESSING_EXCEPTION,
     Document,
-    Entity,
     NumericMetric,
     Section,
 )
@@ -47,61 +45,11 @@
     DebertaForMultiLabelTokenClassification,
     DistilBertForMultiLabelTokenClassification,
 )
+from kazu.training.modelling_utils import LSManagerViewWrapper, chunks
 
 logger = logging.getLogger(__name__)
 
 
-def chunks(lst: list[Any], n: int) -> Iterable[list[Any]]:
-    """Yield successive n-sized chunks from lst."""
-    for i in range(0, len(lst), n):
-        yield lst[i : i + n]
-
-
-class LSManagerViewWrapper:
-    def __init__(self, view: LabelStudioAnnotationView, ls_manager: LabelStudioManager):
-        self.ls_manager = ls_manager
-        self.view = view
-
-    def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list[Document]]:
-        result = []
-        for doc in docs:
-            doc_cp = copy.deepcopy(doc)
-            if PROCESSING_EXCEPTION in doc_cp.metadata:
-                logger.error(doc.metadata[PROCESSING_EXCEPTION])
-                break
-            for section in doc_cp.sections:
-                gold_ents = []
-                for ent in section.metadata.get("gold_entities", []):
-                    if isinstance(ent, dict):
-                        ent = Entity.from_dict(ent)
-                    gold_ents.append(ent)
-                section.entities = gold_ents
-            result.append([doc_cp, doc])
-        return result
-
-    def update(
-        self, test_docs: list[Document], global_step: Union[int, str], has_gs: bool = True
-    ) -> None:
-        ls_manager = LabelStudioManager(
-            headers=self.ls_manager.headers,
-            project_name=f"{self.ls_manager.project_name}_test_{global_step}",
-        )
-
-        ls_manager.delete_project_if_exists()
-        ls_manager.create_linking_project()
-        docs_subset = random.sample(test_docs, min([len(test_docs), 100]))
-        if not docs_subset:
-            logger.info("no results to represent yet")
-            return
-        if has_gs:
-            side_by_side = self.get_gold_ents_for_side_by_side_view(docs_subset)
-            ls_manager.update_view(self.view, side_by_side)
-            ls_manager.update_tasks(side_by_side)
-        else:
-            ls_manager.update_view(self.view, docs_subset)
-            ls_manager.update_tasks(docs_subset)
-
-
 @dataclasses.dataclass
 class SavedModel:
     path: Path
@@ -390,6 +338,7 @@ def __init__(
         self.label_list = label_list
         self.pretrained_model_name_or_path = pretrained_model_name_or_path
         self.keys_to_use = _select_keys_to_use(self.training_config.architecture)
+        random.seed(training_config.seed)
 
     def _write_to_tensorboard(
         self, global_step: int, main_tag: str, tag_scalar_dict: dict[str, NumericMetric]
@@ -413,7 +362,8 @@ def evaluate_model(
         model_test_docs = self._process_docs(model)
 
         if self.ls_wrapper:
-            self.ls_wrapper.update(model_test_docs, global_step)
+            sample_test_docs = random.sample(model_test_docs, min([len(model_test_docs), 100]))
+            self.ls_wrapper.update(sample_test_docs, global_step)
 
         all_results, tensorboad_loggables = calculate_metrics(
             epoch_loss, model_test_docs, self.label_list
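
Reviewer note: the relocated `chunks` helper, the new `TrainingConfig.seed` default, and the ceiling-division progress total are easy to sanity-check in isolation. The snippet below is a minimal sketch, not part of the patch; it assumes the patch is applied and kazu is importable, and the integer list is only a stand-in for real kazu Document objects.

    import random

    from kazu.training.modelling_utils import chunks

    random.seed(42)  # same default as the new TrainingConfig.seed field

    documents = list(range(95))  # stand-in for a list of kazu Documents
    docs_in_batch = 10  # mirrors the hard-coded batch size in evaluate_script.py

    batches = list(chunks(documents, docs_in_batch))
    assert len(batches) == 10  # nine full batches plus a final batch of five
    assert len(batches[-1]) == 5

    # ceiling division: plain floor division would report total=9 to tqdm,
    # leaving the progress bar at 10/9 after the partial final batch
    total = -(-len(documents) // docs_in_batch)
    assert total == len(batches)

    # the Label Studio sub-sampling in evaluate_model, in miniature: at most 100 docs
    sample = random.sample(documents, min(len(documents), 100))
    assert len(sample) == 95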