Skip to content

Commit

Permalink
add unit mixer docs
Browse files Browse the repository at this point in the history
  • Loading branch information
paxcema committed Jan 22, 2024
1 parent 5409551 commit 9b59d6b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion lightwood/mixer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class BaseMixer:

def __init__(self, stop_after: float):
"""
:param stop_after: Time budget to train this mixer.
:param stop_after: Time budget (in seconds) to train this mixer.
"""
self.stop_after = stop_after
self.supports_proba = False
Expand Down
25 changes: 17 additions & 8 deletions lightwood/mixer/unit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
"""
2021.07.16
For encoders that already fine-tune on the targets (namely text)
the unity mixer just arg-maxes the output of the encoder.
"""

from typing import List, Optional

import torch
Expand All @@ -19,19 +12,35 @@

class Unit(BaseMixer):
def __init__(self, stop_after: float, target_encoder: BaseEncoder):
"""
The "Unit" mixer serves as a simple wrapper around a target encoder, essentially borrowing
the encoder's functionality for predictions. In other words, it simply arg-maxes the output of the encoder
Used with encoders that already fine-tune on the targets (namely, pre-trained text ML models).
Attributes:
:param target_encoder: An instance of a Lightwood BaseEncoder. This encoder is used to decode predictions.
:param stop_after (float): Time budget (in seconds) to train this mixer.
""" # noqa
super().__init__(stop_after)
self.target_encoder = target_encoder
self.supports_proba = False
self.stable = True

def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
log.info("Unit Mixer just borrows from encoder")
log.info("Unit mixer does not require training, it passes through predictions from its encoders.")

def partial_fit(self, train_data: EncodedDs, dev_data: EncodedDs, args: Optional[dict] = None) -> None:
pass

def __call__(self, ds: EncodedDs,
args: PredictionArguments = PredictionArguments()) -> pd.DataFrame:
"""
Makes predictions using the provided EncodedDs dataset.
Mixer decodes predictions using the target encoder and returns them in a pandas DataFrame.
:returns ydf (pd.DataFrame): a data frame containing the decoded predictions.
"""
if args.predict_proba:
# @TODO: depending on the target encoder, this might be enabled
log.warning('This model does not output probability estimates')
Expand Down

0 comments on commit 9b59d6b

Please sign in to comment.