diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 86899394..75029699 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -225,6 +225,7 @@ class TransformerModelConfig(ModelConfig): item_net_constructor_kwargs: tp.Optional[InitKwargs] = None pos_encoding_kwargs: tp.Optional[InitKwargs] = None lightning_module_kwargs: tp.Optional[InitKwargs] = None + negative_sampler_kwargs: tp.Optional[InitKwargs] = None similarity_module_kwargs: tp.Optional[InitKwargs] = None backbone_kwargs: tp.Optional[InitKwargs] = None diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 9b0c10db..e1c9fb4b 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -865,6 +865,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "negative_sampler_kwargs": None, "similarity_module_kwargs": None, "backbone_kwargs": None, } diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 4f855da7..24438cc4 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -973,6 +973,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "negative_sampler_kwargs": None, "similarity_module_kwargs": None, "backbone_kwargs": None, }