Skip to content

Commit 43321f2

Browse files
committed
device fix and docs
1 parent 0ed006b commit 43321f2

File tree

6 files changed

+51
-41
lines changed

6 files changed

+51
-41
lines changed

rectools/models/nn/transformers/base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class TransformerModelConfig(ModelConfig):
176176
verbose: int = 0
177177
deterministic: bool = False
178178
recommend_batch_size: int = 256
179-
recommend_device: tp.Optional[str] = None
179+
recommend_torch_device: tp.Optional[str] = None
180180
recommend_n_threads: int = 0
181181
recommend_use_torch_ranking: bool = True
182182
train_min_user_interactions: int = 2
@@ -233,7 +233,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
233233
verbose: int = 0,
234234
deterministic: bool = False,
235235
recommend_batch_size: int = 256,
236-
recommend_device: tp.Optional[str] = None,
236+
recommend_torch_device: tp.Optional[str] = None,
237237
recommend_n_threads: int = 0,
238238
recommend_use_torch_ranking: bool = True,
239239
train_min_user_interactions: int = 2,
@@ -270,7 +270,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
270270
self.epochs = epochs
271271
self.deterministic = deterministic
272272
self.recommend_batch_size = recommend_batch_size
273-
self.recommend_device = recommend_device
273+
self.recommend_torch_device = recommend_torch_device
274274
self.recommend_n_threads = recommend_n_threads
275275
self.recommend_use_torch_ranking = recommend_use_torch_ranking
276276
self.train_min_user_interactions = train_min_user_interactions
@@ -454,7 +454,7 @@ def _recommend_u2i(
454454
dataset=dataset,
455455
n_threads=self.recommend_n_threads,
456456
use_torch_ranking=self.recommend_use_torch_ranking,
457-
device=self.recommend_device,
457+
torch_device=self.recommend_torch_device,
458458
)
459459

460460
def _recommend_i2i(
@@ -473,7 +473,7 @@ def _recommend_i2i(
473473
k=k,
474474
n_threads=self.recommend_n_threads,
475475
use_torch_ranking=self.recommend_use_torch_ranking,
476-
device=self.recommend_device,
476+
torch_device=self.recommend_torch_device,
477477
)
478478

479479
@property

rectools/models/nn/transformers/bert4rec.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,19 +269,24 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
269269
How many samples per batch to load during `recommend`.
270270
If you want to change this parameter after model is initialized,
271271
you can manually assign new value to model `recommend_batch_size` attribute.
272-
recommend_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
273-
String representation for `torch.device` used for recommendations.
272+
recommend_torch_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
273+
String representation for `torch.device` used for torch model inference.
274+
When `recommend_use_torch_ranking` is set to ``True`` (default) this device is also used
275+
for items ranking while preparing recommendations using `TorchRanker`.
274276
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
275277
If you want to change this parameter after model is initialized,
276-
you can manually assign new value to model `recommend_device` attribute.
278+
you can manually assign new value to model `recommend_torch_device` attribute.
277279
recommend_use_torch_ranking : bool, default ``True``
278280
Use `TorchRanker` for items ranking while preparing recommendations.
279-
If set to ``False``, use `ImplicitRanker` instead.
281+
When set to ``True`` (default), device specified in `recommend_torch_device` is used
282+
for items ranking.
283+
When set to ``False``, multi-threaded cpu ranking will be used with `ImplicitRanker`. You
284+
can specify numer of threads using `recommend_n_threads` argument.
280285
If you want to change this parameter after model is initialized,
281286
you can manually assign new value to model `recommend_use_torch_ranking` attribute.
282287
recommend_n_threads : int, default 0
283-
Number of threads to use for `ImplicitRanker`. Omitted if `recommend_use_torch_ranking` is
284-
set to ``True`` (default).
288+
Number of threads to use for cpu items ranking with `ImplicitRanker`. Omitted if
289+
`recommend_use_torch_ranking` is set to ``True`` (default).
285290
If you want to change this parameter after model is initialized,
286291
you can manually assign new value to model `recommend_n_threads` attribute.
287292
data_preparator_kwargs: optional(dict), default ``None``
@@ -333,7 +338,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
333338
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
334339
get_trainer_func: tp.Optional[TrainerCallable] = None,
335340
recommend_batch_size: int = 256,
336-
recommend_device: tp.Optional[str] = None,
341+
recommend_torch_device: tp.Optional[str] = None,
337342
recommend_use_torch_ranking: bool = True,
338343
recommend_n_threads: int = 0,
339344
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
@@ -366,7 +371,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
366371
verbose=verbose,
367372
deterministic=deterministic,
368373
recommend_batch_size=recommend_batch_size,
369-
recommend_device=recommend_device,
374+
recommend_torch_device=recommend_torch_device,
370375
recommend_n_threads=recommend_n_threads,
371376
recommend_use_torch_ranking=recommend_use_torch_ranking,
372377
train_min_user_interactions=train_min_user_interactions,

rectools/models/nn/transformers/lightning.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _recommend_u2i(
115115
filter_viewed: bool,
116116
use_torch_ranking: bool,
117117
n_threads: int,
118-
device: tp.Optional[str],
118+
torch_device: tp.Optional[str],
119119
*args: tp.Any,
120120
**kwargs: tp.Any,
121121
) -> InternalRecoTriplet:
@@ -129,7 +129,7 @@ def _recommend_i2i(
129129
k: int,
130130
use_torch_ranking: bool,
131131
n_threads: int,
132-
device: tp.Optional[str],
132+
torch_device: tp.Optional[str],
133133
*args: tp.Any,
134134
**kwargs: tp.Any,
135135
) -> InternalRecoTriplet:
@@ -295,23 +295,23 @@ def _xavier_normal_init(self) -> None:
295295
if param.data.dim() > 1:
296296
torch.nn.init.xavier_normal_(param.data)
297297

298-
def _prepare_for_inference(self, recommend_device: tp.Optional[str]) -> None:
299-
if recommend_device is None:
300-
recommend_device = "cuda" if torch.cuda.is_available() else "cpu"
301-
device = torch.device(recommend_device)
298+
def _prepare_for_inference(self, torch_device: tp.Optional[str]) -> None:
299+
if torch_device is None:
300+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
301+
device = torch.device(torch_device)
302302
self.torch_model.to(device)
303303
self.torch_model.eval()
304304

305305
def _get_user_item_embeddings(
306306
self,
307307
recommend_dataloader: DataLoader,
308-
recommend_device: tp.Optional[str],
308+
torch_device: tp.Optional[str],
309309
) -> tp.Tuple[torch.Tensor, torch.Tensor]:
310310
"""
311311
Prepare user embeddings for all user interaction sequences in `recommend_dataloader`.
312312
Prepare item embeddings for full items catalog.
313313
"""
314-
self._prepare_for_inference(recommend_device)
314+
self._prepare_for_inference(torch_device)
315315
device = self.torch_model.item_model.device
316316

317317
with torch.no_grad():
@@ -333,14 +333,14 @@ def _recommend_u2i(
333333
filter_viewed: bool,
334334
use_torch_ranking: bool,
335335
n_threads: int,
336-
device: tp.Optional[str],
336+
torch_device: tp.Optional[str],
337337
) -> InternalRecoTriplet:
338338
"""Recommend to users."""
339339
ui_csr_for_filter = None
340340
if filter_viewed:
341341
ui_csr_for_filter = dataset.get_user_item_matrix(include_weights=False, include_warm_items=True)[user_ids]
342342

343-
user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, device)
343+
user_embs, item_embs = self._get_user_item_embeddings(recommend_dataloader, torch_device)
344344

345345
ranker: Ranker
346346
if use_torch_ranking:
@@ -379,10 +379,10 @@ def _recommend_i2i(
379379
k: int,
380380
use_torch_ranking: bool,
381381
n_threads: int,
382-
device: tp.Optional[str],
382+
torch_device: tp.Optional[str],
383383
) -> InternalRecoTriplet:
384384
"""Recommend to items."""
385-
self._prepare_for_inference(device)
385+
self._prepare_for_inference(torch_device)
386386
with torch.no_grad():
387387
item_embs = self.torch_model.item_model.get_all_embeddings()
388388

rectools/models/nn/transformers/sasrec.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -349,19 +349,24 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
349349
How many samples per batch to load during `recommend`.
350350
If you want to change this parameter after model is initialized,
351351
you can manually assign new value to model `recommend_batch_size` attribute.
352-
recommend_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
353-
String representation for `torch.device` used for recommendations.
352+
recommend_torch_device : {"cpu", "cuda", "cuda:0", ...}, default ``None``
353+
String representation for `torch.device` used for torch model inference.
354+
When `recommend_use_torch_ranking` is set to ``True`` (default) this device is also used
355+
for items ranking while preparing recommendations using `TorchRanker`.
354356
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
355357
If you want to change this parameter after model is initialized,
356-
you can manually assign new value to model `recommend_device` attribute.
358+
you can manually assign new value to model `recommend_torch_device` attribute.
357359
recommend_use_torch_ranking : bool, default ``True``
358360
Use `TorchRanker` for items ranking while preparing recommendations.
359-
If set to ``False``, use `ImplicitRanker` instead.
361+
When set to ``True`` (default), device specified in `recommend_torch_device` is used
362+
for items ranking.
363+
When set to ``False``, multi-threaded cpu ranking will be used with `ImplicitRanker`. You
364+
can specify numer of threads using `recommend_n_threads` argument.
360365
If you want to change this parameter after model is initialized,
361366
you can manually assign new value to model `recommend_use_torch_ranking` attribute.
362367
recommend_n_threads : int, default 0
363-
Number of threads to use for `ImplicitRanker`. Omitted if `recommend_use_torch_ranking` is
364-
set to ``True`` (default).
368+
Number of threads to use for cpu items ranking with `ImplicitRanker`. Omitted if
369+
`recommend_use_torch_ranking` is set to ``True`` (default).
365370
If you want to change this parameter after model is initialized,
366371
you can manually assign new value to model `recommend_n_threads` attribute.
367372
data_preparator_kwargs: optional(dict), default ``None``
@@ -412,7 +417,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
412417
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
413418
get_trainer_func: tp.Optional[TrainerCallable] = None,
414419
recommend_batch_size: int = 256,
415-
recommend_device: tp.Optional[str] = None,
420+
recommend_torch_device: tp.Optional[str] = None,
416421
recommend_use_torch_ranking: bool = True,
417422
recommend_n_threads: int = 0,
418423
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
@@ -442,7 +447,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
442447
verbose=verbose,
443448
deterministic=deterministic,
444449
recommend_batch_size=recommend_batch_size,
445-
recommend_device=recommend_device,
450+
recommend_torch_device=recommend_torch_device,
446451
recommend_n_threads=recommend_n_threads,
447452
recommend_use_torch_ranking=recommend_use_torch_ranking,
448453
train_min_user_interactions=train_min_user_interactions,

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def get_trainer() -> Trainer:
113113

114114
@pytest.mark.parametrize("recommend_use_torch_ranking", (True, False))
115115
@pytest.mark.parametrize(
116-
"accelerator,n_devices,recommend_device",
116+
"accelerator,n_devices,recommend_torch_device",
117117
[
118118
("cpu", 1, "cpu"),
119119
pytest.param(
@@ -219,7 +219,7 @@ def test_u2i(
219219
filter_viewed: bool,
220220
accelerator: str,
221221
n_devices: int,
222-
recommend_device: str,
222+
recommend_torch_device: str,
223223
expected_cpu_1: pd.DataFrame,
224224
expected_cpu_2: pd.DataFrame,
225225
expected_gpu_1: pd.DataFrame,
@@ -248,7 +248,7 @@ def get_trainer() -> Trainer:
248248
batch_size=4,
249249
epochs=2,
250250
deterministic=True,
251-
recommend_device=recommend_device,
251+
recommend_torch_device=recommend_torch_device,
252252
item_net_block_types=(IdEmbeddingsItemNet,),
253253
get_trainer_func=get_trainer,
254254
recommend_use_torch_ranking=recommend_use_torch_ranking,
@@ -832,7 +832,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
832832
"epochs": 10,
833833
"verbose": 1,
834834
"deterministic": True,
835-
"recommend_device": None,
835+
"recommend_torch_device": None,
836836
"recommend_batch_size": 256,
837837
"recommend_n_threads": 0,
838838
"recommend_use_torch_ranking": True,

tests/models/nn/transformers/test_sasrec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def get_trainer() -> Trainer:
158158

159159
@pytest.mark.parametrize("recommend_use_torch_ranking", (True, False))
160160
@pytest.mark.parametrize(
161-
"accelerator,devices,recommend_device",
161+
"accelerator,devices,recommend_torch_device",
162162
[
163163
("cpu", 1, "cpu"),
164164
pytest.param(
@@ -250,7 +250,7 @@ def test_u2i(
250250
filter_viewed: bool,
251251
accelerator: str,
252252
devices: tp.Union[int, tp.List[int]],
253-
recommend_device: str,
253+
recommend_torch_device: str,
254254
expected_cpu_1: pd.DataFrame,
255255
expected_cpu_2: pd.DataFrame,
256256
expected_gpu: pd.DataFrame,
@@ -279,7 +279,7 @@ def get_trainer() -> Trainer:
279279
batch_size=4,
280280
epochs=2,
281281
deterministic=True,
282-
recommend_device=recommend_device,
282+
recommend_torch_device=recommend_torch_device,
283283
item_net_block_types=(IdEmbeddingsItemNet,),
284284
get_trainer_func=get_trainer,
285285
recommend_use_torch_ranking=recommend_use_torch_ranking,
@@ -910,7 +910,7 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
910910
"epochs": 10,
911911
"verbose": 1,
912912
"deterministic": True,
913-
"recommend_device": None,
913+
"recommend_torch_device": None,
914914
"recommend_batch_size": 256,
915915
"recommend_n_threads": 0,
916916
"recommend_use_torch_ranking": True,

0 commit comments

Comments
 (0)