Skip to content

Commit 0064ace

Browse files
authored
Feature/lightning recommend (#256)
Moved recommend to lightning module: helps introduce custom tying of user and item embeddings Introduced item net constructor type
1 parent 823377d commit 0064ace

File tree

9 files changed

+782
-614
lines changed

9 files changed

+782
-614
lines changed

rectools/models/nn/bert4rec.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@
2020
import torch
2121

2222
from .constants import MASKING_VALUE, PADDING_VALUE
23-
from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase
23+
from .item_net import (
24+
CatFeaturesItemNet,
25+
IdEmbeddingsItemNet,
26+
ItemNetBase,
27+
ItemNetConstructorBase,
28+
SumOfEmbeddingsConstructor,
29+
)
2430
from .transformer_base import (
2531
TrainerCallable,
2632
TransformerDataPreparatorType,
@@ -234,6 +240,8 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
234240
(IdEmbeddingsItemNet,) - item embeddings based on ids.
235241
(CatFeaturesItemNet,) - item embeddings based on categorical features.
236242
(IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
243+
item_net_constructor_type : type(ItemNetConstructorBase), default `SumOfEmbeddingsConstructor`
244+
Type of item net blocks aggregation constructor.
237245
pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
238246
Type of positional encoding.
239247
transformer_layers_type : type(TransformerLayersBase), default `PreLNTransformerLayers`
@@ -255,16 +263,9 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
255263
How many samples per batch to load during `recommend`.
256264
If you want to change this parameter after model is initialized,
257265
you can manually assign new value to model `recommend_batch_size` attribute.
258-
recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto"
259-
Accelerator type for `recommend`. Used at predict_step of lightning module.
260-
If you want to change this parameter after model is initialized,
261-
you can manually assign new value to model `recommend_accelerator` attribute.
262-
recommend_devices : int | List[int], default 1
263-
Devices for `recommend`. Please note that multi-device inference is not supported!
264-
Do not specify more then one device. For ``gpu`` accelerator you can pass which device to
265-
use, e.g. ``[1]``.
266-
Used at predict_step of lightning module.
267-
Multi-device recommendations are not supported.
266+
recommend_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
267+
String representation for `torch.device` used for recommendations.
268+
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
268269
If you want to change this parameter after model is initialized,
269270
you can manually assign new value to model `recommend_device` attribute.
270271
recommend_n_threads : int, default 0
@@ -301,17 +302,17 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
301302
use_key_padding_mask: bool = True,
302303
use_causal_attn: bool = False,
303304
item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
305+
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
304306
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
305307
transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
306308
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = BERT4RecDataPreparator,
307309
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
308310
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
309311
get_trainer_func: tp.Optional[TrainerCallable] = None,
310312
recommend_batch_size: int = 256,
311-
recommend_accelerator: str = "auto",
312-
recommend_devices: tp.Union[int, tp.List[int]] = 1,
313+
recommend_device: tp.Optional[str] = None,
313314
recommend_n_threads: int = 0,
314-
recommend_use_gpu_ranking: bool = True,
315+
recommend_use_gpu_ranking: bool = True, # TODO: remove after TorchRanker
315316
):
316317
self.mask_prob = mask_prob
317318

@@ -336,12 +337,12 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
336337
verbose=verbose,
337338
deterministic=deterministic,
338339
recommend_batch_size=recommend_batch_size,
339-
recommend_accelerator=recommend_accelerator,
340-
recommend_devices=recommend_devices,
340+
recommend_device=recommend_device,
341341
recommend_n_threads=recommend_n_threads,
342342
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
343343
train_min_user_interactions=train_min_user_interactions,
344344
item_net_block_types=item_net_block_types,
345+
item_net_constructor_type=item_net_constructor_type,
345346
pos_encoding_type=pos_encoding_type,
346347
lightning_module_type=lightning_module_type,
347348
get_val_mask_func=get_val_mask_func,

rectools/models/nn/item_net.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def from_dataset_schema(cls, dataset_schema: DatasetSchema, n_factors: int, drop
231231
return cls(n_factors, n_items, dropout_rate)
232232

233233

234-
class ItemNetConstructor(ItemNetBase):
234+
class ItemNetConstructorBase(ItemNetBase):
235235
"""
236236
Constructed network for item embeddings based on aggregation of embeddings from transferred item network types.
237237
@@ -257,26 +257,6 @@ def __init__(
257257
self.n_item_blocks = len(item_net_blocks)
258258
self.item_net_blocks = nn.ModuleList(item_net_blocks)
259259

260-
def forward(self, items: torch.Tensor) -> torch.Tensor:
261-
"""
262-
Forward pass to get item embeddings from item network blocks.
263-
264-
Parameters
265-
----------
266-
items : torch.Tensor
267-
Internal item ids.
268-
269-
Returns
270-
-------
271-
torch.Tensor
272-
Item embeddings.
273-
"""
274-
item_embs = []
275-
for idx_block in range(self.n_item_blocks):
276-
item_emb = self.item_net_blocks[idx_block](items)
277-
item_embs.append(item_emb)
278-
return torch.sum(torch.stack(item_embs, dim=0), dim=0)
279-
280260
@property
281261
def catalog(self) -> torch.Tensor:
282262
"""Return tensor with elements in range [0, n_items)."""
@@ -336,3 +316,52 @@ def from_dataset_schema(
336316
item_net_blocks.append(item_net_block)
337317

338318
return cls(n_items, item_net_blocks)
319+
320+
def forward(self, items: torch.Tensor) -> torch.Tensor:
321+
"""Forward pass through item net blocks and aggregation of the results.
322+
323+
Parameters
324+
----------
325+
items : torch.Tensor
326+
Internal item ids.
327+
328+
Returns
329+
-------
330+
torch.Tensor
331+
Item embeddings.
332+
"""
333+
raise NotImplementedError()
334+
335+
336+
class SumOfEmbeddingsConstructor(ItemNetConstructorBase):
337+
"""
338+
Item net blocks constructor that simply sums all of the its net blocks embeddings.
339+
340+
Parameters
341+
----------
342+
n_items : int
343+
Number of items in the dataset.
344+
item_net_blocks : Sequence(ItemNetBase)
345+
Latent embedding size of item embeddings.
346+
"""
347+
348+
def forward(self, items: torch.Tensor) -> torch.Tensor:
349+
"""
350+
Forward pass through item net blocks and aggregation of the results.
351+
Simple sum of embeddings.
352+
353+
Parameters
354+
----------
355+
items : torch.Tensor
356+
Internal item ids.
357+
358+
Returns
359+
-------
360+
torch.Tensor
361+
Item embeddings.
362+
"""
363+
item_embs = []
364+
for idx_block in range(self.n_item_blocks):
365+
item_emb = self.item_net_blocks[idx_block](items)
366+
item_embs.append(item_emb)
367+
return torch.sum(torch.stack(item_embs, dim=0), dim=0)

rectools/models/nn/sasrec.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
import torch
2020
from torch import nn
2121

22-
from .item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, ItemNetBase
22+
from .item_net import (
23+
CatFeaturesItemNet,
24+
IdEmbeddingsItemNet,
25+
ItemNetBase,
26+
ItemNetConstructorBase,
27+
SumOfEmbeddingsConstructor,
28+
)
2329
from .transformer_base import (
2430
TrainerCallable,
2531
TransformerDataPreparatorType,
@@ -263,6 +269,8 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
263269
(IdEmbeddingsItemNet,) - item embeddings based on ids.
264270
(CatFeaturesItemNet,) - item embeddings based on categorical features.
265271
(IdEmbeddingsItemNet, CatFeaturesItemNet) - item embeddings based on ids and categorical features.
272+
item_net_constructor_type : type(ItemNetConstructorBase), default `SumOfEmbeddingsConstructor`
273+
Type of item net blocks aggregation constructor.
266274
pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
267275
Type of positional encoding.
268276
transformer_layers_type : type(TransformerLayersBase), default `SasRecTransformerLayers`
@@ -284,16 +292,9 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
284292
How many samples per batch to load during `recommend`.
285293
If you want to change this parameter after model is initialized,
286294
you can manually assign new value to model `recommend_batch_size` attribute.
287-
recommend_accelerator : {"cpu", "gpu", "tpu", "hpu", "mps", "auto"}, default "auto"
288-
Accelerator type for `recommend`. Used at predict_step of lightning module.
289-
If you want to change this parameter after model is initialized,
290-
you can manually assign new value to model `recommend_accelerator` attribute.
291-
recommend_devices : int | List[int], default 1
292-
Devices for `recommend`. Please note that multi-device inference is not supported!
293-
Do not specify more then one device. For ``gpu`` accelerator you can pass which device to
294-
use, e.g. ``[1]``.
295-
Used at predict_step of lightning module.
296-
Multi-device recommendations are not supported.
295+
recommend_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
296+
String representation for `torch.device` used for recommendations.
297+
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
297298
If you want to change this parameter after model is initialized,
298299
you can manually assign new value to model `recommend_device` attribute.
299300
recommend_n_threads : int, default 0
@@ -329,17 +330,17 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
329330
use_key_padding_mask: bool = False,
330331
use_causal_attn: bool = True,
331332
item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
333+
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
332334
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
333335
transformer_layers_type: tp.Type[TransformerLayersBase] = SASRecTransformerLayers, # SASRec authors net
334336
data_preparator_type: tp.Type[TransformerDataPreparatorBase] = SASRecDataPreparator,
335337
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
336338
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
337339
get_trainer_func: tp.Optional[TrainerCallable] = None,
338340
recommend_batch_size: int = 256,
339-
recommend_accelerator: str = "auto",
340-
recommend_devices: tp.Union[int, tp.List[int]] = 1,
341+
recommend_device: tp.Optional[str] = None,
341342
recommend_n_threads: int = 0,
342-
recommend_use_gpu_ranking: bool = True,
343+
recommend_use_gpu_ranking: bool = True, # TODO: remove after TorchRanker
343344
):
344345
super().__init__(
345346
transformer_layers_type=transformer_layers_type,
@@ -362,12 +363,12 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
362363
verbose=verbose,
363364
deterministic=deterministic,
364365
recommend_batch_size=recommend_batch_size,
365-
recommend_accelerator=recommend_accelerator,
366-
recommend_devices=recommend_devices,
366+
recommend_device=recommend_device,
367367
recommend_n_threads=recommend_n_threads,
368368
recommend_use_gpu_ranking=recommend_use_gpu_ranking,
369369
train_min_user_interactions=train_min_user_interactions,
370370
item_net_block_types=item_net_block_types,
371+
item_net_constructor_type=item_net_constructor_type,
371372
pos_encoding_type=pos_encoding_type,
372373
lightning_module_type=lightning_module_type,
373374
get_val_mask_func=get_val_mask_func,

0 commit comments

Comments
 (0)