From 660594cb8b40c785391c06399af20e119c251f8d Mon Sep 17 00:00:00 2001 From: spirinamayya Date: Thu, 10 Apr 2025 13:36:31 +0300 Subject: [PATCH] fix negative sampler kwargs config --- rectools/models/nn/transformers/base.py | 1 + tests/models/nn/transformers/test_bert4rec.py | 1 + tests/models/nn/transformers/test_sasrec.py | 1 + 3 files changed, 3 insertions(+) diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index 86899394..75029699 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -225,6 +225,7 @@ class TransformerModelConfig(ModelConfig): item_net_constructor_kwargs: tp.Optional[InitKwargs] = None pos_encoding_kwargs: tp.Optional[InitKwargs] = None lightning_module_kwargs: tp.Optional[InitKwargs] = None + negative_sampler_kwargs: tp.Optional[InitKwargs] = None similarity_module_kwargs: tp.Optional[InitKwargs] = None backbone_kwargs: tp.Optional[InitKwargs] = None diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 9b0c10db..e1c9fb4b 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -865,6 +865,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "negative_sampler_kwargs": None, "similarity_module_kwargs": None, "backbone_kwargs": None, } diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 4f855da7..24438cc4 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -973,6 +973,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "item_net_constructor_kwargs": None, "pos_encoding_kwargs": None, "lightning_module_kwargs": None, + "negative_sampler_kwargs": None, "similarity_module_kwargs": None, "backbone_kwargs": None, }