Skip to content

Commit 8666773

Browse files
committed
fixed ranking
1 parent 98908fa commit 8666773

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

rectools/models/nn/transformer_lightning.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ def recommend_u2i(
336336
Distance.DOT,
337337
user_embs_np[user_ids], # [n_rec_users, n_factors]
338338
item_embs_np, # [n_items + n_item_extra_tokens, n_factors]
339+
num_threads=recommend_n_threads,
340+
use_gpu=recommend_use_gpu_ranking and HAS_CUDA,
339341
)
340342

341343
# TODO: We should test if torch `topk`` is faster when `filter_viewed`` is ``False``
@@ -344,8 +346,6 @@ def recommend_u2i(
344346
k=k,
345347
filter_pairs_csr=ui_csr_for_filter, # [n_rec_users x n_items + n_item_extra_tokens]
346348
sorted_object_whitelist=sorted_item_ids_to_recommend, # model_internal
347-
num_threads=recommend_n_threads,
348-
use_gpu=recommend_use_gpu_ranking and HAS_CUDA,
349349
)
350350
all_user_ids = user_ids[user_ids_indices]
351351
return all_user_ids, all_reco_ids, all_scores
@@ -371,12 +371,12 @@ def recommend_i2i(
371371
self.i2i_dist,
372372
item_embs, # [n_items + n_item_extra_tokens, n_factors]
373373
item_embs, # [n_items + n_item_extra_tokens, n_factors]
374+
num_threads=recommend_n_threads,
375+
use_gpu=recommend_use_gpu_ranking and HAS_CUDA,
374376
)
375377
return ranker.rank(
376378
subject_ids=target_ids, # model internal
377379
k=k,
378380
filter_pairs_csr=None,
379381
sorted_object_whitelist=sorted_item_ids_to_recommend, # model internal
380-
num_threads=recommend_n_threads,
381-
use_gpu=recommend_use_gpu_ranking and HAS_CUDA,
382382
)

0 commit comments

Comments
 (0)