13 changes: 10 additions & 3 deletions rectools/models/nn/transformers/base.py
@@ -161,7 +161,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
]


ValMaskCallable = Callable[[], np.ndarray]
ValMaskCallable = Callable[..., np.ndarray]

ValMaskCallableSerialized = tpe.Annotated[
ValMaskCallable,
@@ -173,7 +173,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
),
]

TrainerCallable = Callable[[], Trainer]
TrainerCallable = Callable[..., Trainer]

TrainerCallableSerialized = tpe.Annotated[
TrainerCallable,
@@ -220,6 +220,8 @@ class TransformerModelConfig(ModelConfig):
backbone_type: TransformerBackboneType = TransformerTorchBackbone
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None
data_preparator_kwargs: tp.Optional[InitKwargs] = None
transformer_layers_kwargs: tp.Optional[InitKwargs] = None
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
@@ -280,6 +282,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
@@ -321,6 +325,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.backbone_type = backbone_type
self.get_val_mask_func = get_val_mask_func
self.get_trainer_func = get_trainer_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
self.get_trainer_func_kwargs = get_trainer_func_kwargs
self.data_preparator_kwargs = data_preparator_kwargs
self.transformer_layers_kwargs = transformer_layers_kwargs
self.item_net_constructor_kwargs = item_net_constructor_kwargs
@@ -354,6 +360,7 @@ def _init_data_preparator(self) -> None:
negative_sampler=self._init_negative_sampler() if requires_negatives else None,
n_negatives=self.n_negatives if requires_negatives else None,
get_val_mask_func=self.get_val_mask_func,
get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
shuffle_train=True,
**self._get_kwargs(self.data_preparator_kwargs),
)
@@ -371,7 +378,7 @@ def _init_trainer(self) -> None:
devices=1,
)
else:
self._trainer = self.get_trainer_func()
self._trainer = self.get_trainer_func(**self._get_kwargs(self.get_trainer_func_kwargs))

def _init_negative_sampler(self) -> TransformerNegativeSamplerBase:
return self.negative_sampler_type(
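Note on the relaxed aliases above: `ValMaskCallable` and `TrainerCallable` now accept arbitrary keyword arguments, so a trainer factory can take configuration that `_init_trainer` forwards from `get_trainer_func_kwargs`, or receive no arguments at all when that field is None. A minimal sketch of a conforming factory; the `pytorch_lightning` import is assumed to match the Trainer class the library builds on:

import typing as tp

from pytorch_lightning import Trainer  # assumed to be the Trainer class rectools uses


def get_trainer(**kwargs: tp.Any) -> Trainer:
    # Called as get_trainer_func(**get_trainer_func_kwargs) when kwargs are configured,
    # and as get_trainer_func() when get_trainer_func_kwargs is None.
    return Trainer(**kwargs)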
13 changes: 13 additions & 0 deletions rectools/models/nn/transformers/bert4rec.py
@@ -72,6 +72,8 @@ class BERT4RecDataPreparator(TransformerDataPreparatorBase):
Negative sampler.
mask_prob : float, default 0.15
Probability of masking an item in interactions sequence.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments for the get_val_mask_func.
"""

train_session_max_len_addition: int = 0
@@ -88,6 +90,7 @@ def __init__(
mask_prob: float = 0.15,
shuffle_train: bool = True,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
super().__init__(
@@ -99,6 +102,7 @@ def __init__(
train_min_user_interactions=train_min_user_interactions,
shuffle_train=shuffle_train,
get_val_mask_func=get_val_mask_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
)
self.mask_prob = mask_prob

@@ -301,6 +305,10 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_torch_device` attribute.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments for the get_val_mask_func.
get_trainer_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments for the get_trainer_func.
data_preparator_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `data_preparator_type` initialization.
Make sure all dict values have JSON serializable types.
@@ -361,6 +369,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
recommend_batch_size: int = 256,
recommend_torch_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
@@ -411,6 +421,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
get_trainer_func_kwargs=get_trainer_func_kwargs,
data_preparator_kwargs=data_preparator_kwargs,
transformer_layers_kwargs=transformer_layers_kwargs,
item_net_block_kwargs=item_net_block_kwargs,
@@ -433,6 +445,7 @@ def _init_data_preparator(self) -> None:
train_min_user_interactions=self.train_min_user_interactions,
mask_prob=self.mask_prob,
get_val_mask_func=self.get_val_mask_func,
get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
shuffle_train=True,
**self._get_kwargs(self.data_preparator_kwargs),
)
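Taken together, the BERT4Rec changes let callers parametrize both factories at construction time. A hedged usage sketch: import paths are taken from the public rectools API, the `val_users` argument mirrors the test fixture further below and is not a built-in option, and all values are illustrative only:

import typing as tp

import numpy as np
import pandas as pd
from pytorch_lightning import Trainer  # assumed Trainer class

from rectools import Columns
from rectools.models import BERT4RecModel


def get_val_mask(interactions: pd.DataFrame, val_users: tp.Sequence[int]) -> np.ndarray:
    # Mark the last interaction of each selected user as a validation target.
    rank = (
        interactions.sort_values(Columns.Datetime, ascending=False, kind="stable")
        .groupby(Columns.User, sort=False)
        .cumcount()
        + 1
    )
    return ((interactions[Columns.User].isin(val_users)) & (rank <= 1)).to_numpy()


def get_trainer(**kwargs: tp.Any) -> Trainer:
    return Trainer(**kwargs)


model = BERT4RecModel(
    get_val_mask_func=get_val_mask,
    get_val_mask_func_kwargs={"val_users": [10, 30]},
    get_trainer_func=get_trainer,
    get_trainer_func_kwargs={"max_epochs": 2, "accelerator": "cpu", "devices": 1},
)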
15 changes: 14 additions & 1 deletion rectools/models/nn/transformers/data_preparator.py
@@ -31,6 +31,8 @@
from .constants import PADDING_VALUE
from .negative_sampler import TransformerNegativeSamplerBase

InitKwargs = tp.Dict[str, tp.Any]


class SequenceDataset(TorchDataset):
"""
@@ -109,6 +111,8 @@ class TransformerDataPreparatorBase:
Number of negatives for BCE, gBCE and sampled_softmax losses.
negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
Negative sampler.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments for the get_val_mask_func.
"""

# We sometimes need data preparators to add +1 to actual session_max_len
@@ -127,6 +131,7 @@ def __init__(
get_val_mask_func: tp.Optional[tp.Callable] = None,
n_negatives: tp.Optional[int] = None,
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
self.item_id_map: IdMap
@@ -141,6 +146,7 @@ def __init__(
self.train_min_user_interactions = train_min_user_interactions
self.shuffle_train = shuffle_train
self.get_val_mask_func = get_val_mask_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs

def get_known_items_sorted_internal_ids(self) -> np.ndarray:
"""Return internal item ids from processed dataset in sorted order."""
@@ -150,6 +156,13 @@ def get_known_item_ids(self) -> np.ndarray:
"""Return external item ids from processed dataset in sorted order."""
return self.item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :]

@staticmethod
def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
"""Return the configured kwargs or an empty dict when ``None`` is set."""
return actual_kwargs if actual_kwargs is not None else {}

@property
def n_item_extra_tokens(self) -> int:
"""Return number of padding elements"""
@@ -194,7 +207,7 @@ def process_dataset_train(self, dataset: Dataset) -> None:
# Exclude val interaction targets from train if needed
interactions = raw_interactions
if self.get_val_mask_func is not None:
val_mask = self.get_val_mask_func(raw_interactions)
val_mask = self.get_val_mask_func(raw_interactions, **self._get_kwargs(self.get_val_mask_func_kwargs))
interactions = raw_interactions[~val_mask]
interactions.reset_index(drop=True, inplace=True)

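The `_get_kwargs` helper above turns None into an empty dict, so `process_dataset_train` calls the mask function either with the interactions frame alone or with the configured keyword arguments unpacked after it. A small standalone illustration of a callable that satisfies both call shapes (the `val_users` key and the sample frame are invented for this sketch):

import typing as tp

import numpy as np
import pandas as pd


def get_val_mask(interactions: pd.DataFrame, **kwargs: tp.Any) -> np.ndarray:
    # Extra arguments, if any, arrive from get_val_mask_func_kwargs.
    val_users = kwargs.get("val_users", [])
    return interactions["user_id"].isin(val_users).to_numpy()


df = pd.DataFrame({"user_id": [10, 20, 30]})
print(get_val_mask(df))                      # get_val_mask_func_kwargs is None
print(get_val_mask(df, val_users=[10, 30]))  # get_val_mask_func_kwargs = {"val_users": [10, 30]}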
10 changes: 10 additions & 0 deletions rectools/models/nn/transformers/sasrec.py
@@ -72,6 +72,8 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):
Number of negatives for BCE, gBCE and sampled_softmax losses.
negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
Negative sampler.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments for the get_val_mask_func.
"""

train_session_max_len_addition: int = 1
Expand Down Expand Up @@ -379,6 +381,10 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_torch_device` attribute.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments for the get_val_mask_func.
get_trainer_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments for the get_trainer_func.
data_preparator_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `data_preparator_type` initialization.
Make sure all dict values have JSON serializable types.
@@ -438,6 +444,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
recommend_batch_size: int = 256,
recommend_torch_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
@@ -485,6 +493,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
get_trainer_func_kwargs=get_trainer_func_kwargs,
data_preparator_kwargs=data_preparator_kwargs,
transformer_layers_kwargs=transformer_layers_kwargs,
item_net_constructor_kwargs=item_net_constructor_kwargs,
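The SASRec changes mirror the BERT4Rec ones: `SASRecModel` accepts the same two parameters and forwards them through the shared base class. A minimal sketch with the same assumptions about import paths and the Trainer class as above, and purely illustrative values:

import typing as tp

from pytorch_lightning import Trainer  # assumed Trainer class

from rectools.models import SASRecModel


def get_trainer(**kwargs: tp.Any) -> Trainer:
    return Trainer(**kwargs)


model = SASRecModel(
    get_trainer_func=get_trainer,
    get_trainer_func_kwargs={"max_epochs": 2, "accelerator": "cpu", "devices": 1},
)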
94 changes: 92 additions & 2 deletions tests/models/nn/transformers/test_bert4rec.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import typing as tp
from functools import partial

@@ -31,6 +32,7 @@
PreLNTransformerLayers,
TrainerCallable,
TransformerLightningModule,
ValMaskCallable,
)
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
@@ -44,6 +46,8 @@

from .utils import custom_trainer, leave_one_out_mask

InitKwargs = tp.Dict[str, tp.Any]


class TestBERT4RecModel:
def setup_method(self) -> None:
@@ -114,6 +118,14 @@ def get_trainer() -> Trainer:

return get_trainer

@pytest.fixture
def factory_get_trainer_func(self) -> TrainerCallable:
def get_trainer_with_kwargs(**get_trainer_kwargs: tp.Any) -> Trainer:
get_trainer_kwargs = get_trainer_kwargs or {}
return Trainer(**get_trainer_kwargs)

return get_trainer_with_kwargs

@pytest.mark.parametrize(
"accelerator,n_devices,recommend_torch_device",
[
@@ -310,14 +322,29 @@ def get_trainer() -> Trainer:
),
)
@pytest.mark.parametrize("u2i_dist", ("dot", "cosine"))
@pytest.mark.parametrize(
"get_trainer_func_kwargs",
(
{
"max_epochs": 2,
"min_epochs": 2,
"deterministic": True,
"accelerator": "cpu",
"enable_checkpointing": False,
"devices": 1,
},
),
)
def test_u2i_losses(
self,
dataset_devices: Dataset,
loss: str,
get_trainer_func: TrainerCallable,
factory_get_trainer_func: TrainerCallable,
expected: pd.DataFrame,
u2i_dist: str,
get_trainer_func_kwargs: InitKwargs,
) -> None:
assert set(get_trainer_func_kwargs.keys()).issubset(inspect.signature(Trainer.__init__).parameters.keys())
model = BERT4RecModel(
n_negatives=2,
n_factors=32,
@@ -330,7 +357,8 @@ def test_u2i_losses(
deterministic=True,
mask_prob=0.6,
item_net_block_types=(IdEmbeddingsItemNet,),
get_trainer_func=get_trainer_func,
get_trainer_func=factory_get_trainer_func,
get_trainer_func_kwargs=get_trainer_func_kwargs,
loss=loss,
similarity_module_type=DistanceSimilarityModule,
similarity_module_kwargs={"distance": u2i_dist},
@@ -561,6 +589,7 @@ def __init__(
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
shuffle_train: bool = True,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
n_last_targets: int = 1, # custom kwarg
) -> None:
super().__init__(
@@ -572,6 +601,7 @@ def __init__(
negative_sampler=negative_sampler,
shuffle_train=shuffle_train,
get_val_mask_func=get_val_mask_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
mask_prob=mask_prob,
)
self.n_last_targets = n_last_targets
@@ -728,6 +758,35 @@ def get_val_mask(interactions: pd.DataFrame, val_users: ExternalIds) -> np.ndarray:
get_val_mask_func=get_val_mask_func,
)

@pytest.fixture
def factory_data_preparator_val_mask_with_kwargs(self) -> tp.Callable[..., BERT4RecDataPreparator]:
def data_preparator_val_mask_with_kwargs(get_val_mask_func_kwargs: InitKwargs) -> BERT4RecDataPreparator:
def get_val_mask(interactions: pd.DataFrame, **kwargs) -> np.ndarray:
val_users = kwargs.get("val_users")
rank = (
interactions.sort_values(Columns.Datetime, ascending=False, kind="stable")
.groupby(Columns.User, sort=False)
.cumcount()
+ 1
)
val_mask = (interactions[Columns.User].isin(val_users)) & (rank <= 1)
return val_mask.values

assert "val_users" in get_val_mask_func_kwargs

return BERT4RecDataPreparator(
session_max_len=4,
n_negatives=2,
train_min_user_interactions=2,
mask_prob=0.5,
batch_size=4,
dataloader_num_workers=0,
get_val_mask_func=get_val_mask,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
)

return data_preparator_val_mask_with_kwargs

@pytest.mark.parametrize(
"train_batch",
(
@@ -816,6 +875,35 @@ def test_get_dataloader_val(
for key, value in actual.items():
assert torch.equal(value, val_batch[key])

@pytest.mark.parametrize("val_users", ([10, 30],))
@pytest.mark.parametrize(
"val_batch",
(
(
{
"x": torch.tensor([[0, 2, 4, 1]]),
"y": torch.tensor([[3]]),
"yw": torch.tensor([[1.0]]),
"negatives": torch.tensor([[[5, 2]]]),
}
),
),
)
def test_get_dataloader_val_with_kwargs(
self,
dataset: Dataset,
factory_data_preparator_val_mask_with_kwargs: tp.Callable[..., BERT4RecDataPreparator],
val_users: tp.List[int],
val_batch: tp.Dict[str, torch.Tensor],
) -> None:
kwargs = {"val_users": val_users}
data_preparator_val_mask = factory_data_preparator_val_mask_with_kwargs(get_val_mask_func_kwargs=kwargs)
data_preparator_val_mask.process_dataset_train(dataset)
dataloader = data_preparator_val_mask.get_dataloader_val()
actual = next(iter(dataloader)) # type: ignore
for key, value in actual.items():
assert torch.equal(value, val_batch[key])


class TestBERT4RecModelConfiguration:
def setup_method(self) -> None:
@@ -860,6 +948,8 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
"mask_prob": 0.15,
"get_val_mask_func": leave_one_out_mask,
"get_trainer_func": None,
"get_val_mask_func_kwargs": None,
"get_trainer_func_kwargs": None,
"data_preparator_kwargs": None,
"transformer_layers_kwargs": None,
"item_net_constructor_kwargs": None,
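The updated `initial_config` fixture indicates that the two new fields join the serialized model config alongside the other optional kwargs, defaulting to None. A hedged sketch of what that implies, assuming `from_config` accepts a partial dict as it does for other optional fields:

from rectools.models import BERT4RecModel

model = BERT4RecModel.from_config({"get_trainer_func_kwargs": {"max_epochs": 2}})
assert model.get_trainer_func_kwargs == {"max_epochs": 2}
assert model.get_val_mask_func_kwargs is None  # new fields default to None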