Skip to content

Commit

Permalink
refactor: save out predictions in eval script
Browse files Browse the repository at this point in the history
  • Loading branch information
paluchasz committed Dec 6, 2024
1 parent 02bc292 commit 068511c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
9 changes: 9 additions & 0 deletions kazu/training/evaluate_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from hydra.utils import instantiate
from omegaconf import DictConfig

from kazu.data import Document
from kazu.pipeline import Pipeline
from kazu.steps.ner.hf_token_classification import (
TransformersModelForTokenClassificationNerStep,
Expand All @@ -30,6 +31,13 @@
from kazu.utils.constants import HYDRA_VERSION_BASE


def save_out_predictions(output_dir: Path, documents: list[Document]) -> None:
    """Write every document to its own JSON file inside ``output_dir``.

    Each document is serialised with ``doc.to_json()`` and written to
    ``<output_dir>/<doc.idx>.json``. An existing file with the same name is
    overwritten.

    :param output_dir: directory that receives the prediction files. It is
        created (including missing parents) if it does not exist yet.
    :param documents: documents to save; each must expose ``idx`` and
        ``to_json()``.
    """
    # Create the target directory up front so a fresh predictions_dir does not
    # make the script fail with FileNotFoundError after prediction has already
    # been run.
    output_dir.mkdir(parents=True, exist_ok=True)
    for doc in documents:
        file_path = output_dir / f"{doc.idx}.json"
        # Pin the encoding: the platform default is locale dependent and may
        # not be able to represent non-ASCII document text.
        file_path.write_text(doc.to_json(), encoding="utf-8")


@hydra.main(
version_base=HYDRA_VERSION_BASE,
config_path=str(
Expand Down Expand Up @@ -64,6 +72,7 @@ def main(cfg: DictConfig) -> None:
pipeline(documents)
print(f"Predicted {len(documents)} documents in {time.time() - start:.2f} seconds.")

save_out_predictions(Path(cfg.predictions_dir), documents)
print("Calculating metrics")
metrics, _ = calculate_metrics(0, documents, label_list)
with open(Path(prediction_config.path) / "test_metrics.json", "w") as file:
Expand Down
1 change: 1 addition & 0 deletions scripts/examples/conf/multilabel_ner_evaluate/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ prediction_config:
device: cpu
architecture: bert
use_multilabel: true
predictions_dir: ???
css_colors:
- "#000000" # Black
- "#FF0000" # Red
Expand Down

0 comments on commit 068511c

Please sign in to comment.