2020import torch
2121
2222from .constants import MASKING_VALUE , PADDING_VALUE
23- from .item_net import CatFeaturesItemNet , IdEmbeddingsItemNet , ItemNetBase
23+ from .item_net import (
24+ CatFeaturesItemNet ,
25+ IdEmbeddingsItemNet ,
26+ ItemNetBase ,
27+ ItemNetConstructorBase ,
28+ SumOfEmbeddingsConstructor ,
29+ )
2430from .transformer_base import (
2531 TrainerCallable ,
2632 TransformerDataPreparatorType ,
@@ -234,6 +240,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
234240 (IdEmbeddingsItemNet,) - item embeddings based on ids.
235241 (CatFeaturesItemNet,) - item embeddings based on categorical features.
236242 (IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
243+ item_net_constructor_type : type(ItemNetConstructorBase), default `SumOfEmbeddingsConstructor`
244+ Type of item net blocks aggregation constructor.
237245 pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
238246 Type of positional encoding.
239247 transformer_layers_type : type(TransformerLayersBase), default `PreLNTransformerLayers`
@@ -255,16 +263,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
255263 How many samples per batch to load during `recommend`.
256264 If you want to change this parameter after model is initialized,
257265 you can manually assign new value to model `recommend_batch_size` attribute.
258- recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto"
259- Accelerator type for `recommend`. Used at predict_step of lightning module.
260- If you want to change this parameter after model is initialized,
261- you can manually assign new value to model `recommend_accelerator` attribute.
262- recommend_devices : int | List[int], default 1
263- Devices for `recommend`. Please note that multi-device inference is not supported!
264- Do not specify more then one device. For ``gpu`` accelerator you can pass which device to
265- use, e.g. ``[1]``.
266- Used at predict_step of lightning module.
267- Multi-device recommendations are not supported.
266+ recommend_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
267+ String representation for `torch.device` used for recommendations.
268+ When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
268269 If you want to change this parameter after model is initialized,
269270 you can manually assign new value to model `recommend_device` attribute.
270271 recommend_n_threads : int, default 0
@@ -301,17 +302,17 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
301302 use_key_padding_mask : bool = True ,
302303 use_causal_attn : bool = False ,
303304 item_net_block_types : tp .Sequence [tp .Type [ItemNetBase ]] = (IdEmbeddingsItemNet , CatFeaturesItemNet ),
305+ item_net_constructor_type : tp .Type [ItemNetConstructorBase ] = SumOfEmbeddingsConstructor ,
304306 pos_encoding_type : tp .Type [PositionalEncodingBase ] = LearnableInversePositionalEncoding ,
305307 transformer_layers_type : tp .Type [TransformerLayersBase ] = PreLNTransformerLayers ,
306308 data_preparator_type : tp .Type [TransformerDataPreparatorBase ] = BERT4RecDataPreparator ,
307309 lightning_module_type : tp .Type [TransformerLightningModuleBase ] = TransformerLightningModule ,
308310 get_val_mask_func : tp .Optional [ValMaskCallable ] = None ,
309311 get_trainer_func : tp .Optional [TrainerCallable ] = None ,
310312 recommend_batch_size : int = 256 ,
311- recommend_accelerator : str = "auto" ,
312- recommend_devices : tp .Union [int , tp .List [int ]] = 1 ,
313+ recommend_device : tp .Optional [str ] = None ,
313314 recommend_n_threads : int = 0 ,
314- recommend_use_gpu_ranking : bool = True ,
315+ recommend_use_gpu_ranking : bool = True , # TODO: remove after TorchRanker
315316 ):
316317 self .mask_prob = mask_prob
317318
@@ -336,12 +337,12 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
336337 verbose = verbose ,
337338 deterministic = deterministic ,
338339 recommend_batch_size = recommend_batch_size ,
339- recommend_accelerator = recommend_accelerator ,
340- recommend_devices = recommend_devices ,
340+ recommend_device = recommend_device ,
341341 recommend_n_threads = recommend_n_threads ,
342342 recommend_use_gpu_ranking = recommend_use_gpu_ranking ,
343343 train_min_user_interactions = train_min_user_interactions ,
344344 item_net_block_types = item_net_block_types ,
345+ item_net_constructor_type = item_net_constructor_type ,
345346 pos_encoding_type = pos_encoding_type ,
346347 lightning_module_type = lightning_module_type ,
347348 get_val_mask_func = get_val_mask_func ,
0 commit comments