Skip to content

Commit

Permalink
Create custom NotebookCallback subclass for embedding_loss, etc. (#557)
Browse files Browse the repository at this point in the history
* Create custom NotebookCallback subclass for embedding_loss, etc.

* Move notebook code into separate file so IPython isn't required

* Add docstring for SetFitNotebookProgressCallback
  • Loading branch information
tomaarsen authored Sep 19, 2024
1 parent 223afb6 commit 35755c6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 10 deletions.
52 changes: 52 additions & 0 deletions src/setfit/notebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import re

from transformers.utils.notebook import NotebookProgressCallback


class SetFitNotebookProgressCallback(NotebookProgressCallback):
"""
A variation of NotebookProgressCallback that accepts logs/metrics other than "loss" and "eval_loss".
In particular, it accepts "embedding_loss", "aspect_embedding_loss", and "polarity_embedding_loss"
and the corresponding metrics for the validation set.
"""
def on_log(self, *args, logs=None, **kwargs):
if logs is not None:
logs = {key if key != "embedding_loss" else "loss": value for key, value in logs.items()}
return super().on_log(*args, logs=logs, **kwargs)

def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if self.training_tracker is not None:
values = {"Training Loss": "No log", "Validation Loss": "No log"}
for log in reversed(state.log_history):
if loss_logs := {
key for key in log if key in ("embedding_loss", "aspect_embedding_loss", "polarity_embedding_loss")
}:
values["Training Loss"] = log[loss_logs.pop()]
break

if self.first_column == "Epoch":
values["Epoch"] = int(state.epoch)
else:
values["Step"] = state.global_step
metric_key_prefix = "eval"
for k in metrics:
if k.endswith("_loss"):
metric_key_prefix = re.sub(r"\_loss$", "", k)
_ = metrics.pop("total_flos", None)
_ = metrics.pop("epoch", None)
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
for k, v in metrics.items():
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
if name in ("Embedding Loss", "Aspect Embedding Loss", "Polarity Embedding Loss"):
# Single dataset
name = "Validation Loss"
values[name] = v
self.training_tracker.write_line(values)
self.training_tracker.remove_child()
self.prediction_bar = None
# Evaluation takes a long time so we should force the next update.
self._force_next_update = True
19 changes: 9 additions & 10 deletions src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sklearn.preprocessing import LabelEncoder
from torch import nn
from transformers.integrations import CodeCarbonCallback
from transformers.trainer_callback import DefaultFlowCallback, IntervalStrategy, ProgressCallback, TrainerCallback
from transformers.trainer_callback import IntervalStrategy, TrainerCallback
from transformers.trainer_utils import HPSearchBackend, default_compute_objective, number_of_arguments, set_seed
from transformers.utils.import_utils import is_in_notebook

Expand All @@ -34,15 +34,6 @@
logger = logging.get_logger(__name__)


DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
from transformers.utils.notebook import NotebookProgressCallback

DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback


class BCSentenceTransformersTrainer(SentenceTransformerTrainer):
"""
Subclass of SentenceTransformerTrainer that is backwards compatible with the SetFit API.
Expand All @@ -64,6 +55,14 @@ def __init__(
if isinstance(callback, STModelCardCallback):
self.remove_callback(callback)

if is_in_notebook():
from transformers.utils.notebook import NotebookProgressCallback

from setfit.notebook import SetFitNotebookProgressCallback

if self.pop_callback(NotebookProgressCallback):
self.add_callback(SetFitNotebookProgressCallback)

def overwritten_call_event(self, event, args, state, control, **kwargs):
for callback in self.callbacks:
result = getattr(callback, event)(
Expand Down

0 comments on commit 35755c6

Please sign in to comment.