|
9 | 9 | "1. How to apply SASRec and BERT4Rec transformer models using RecTools?\n", |
10 | 10 | "2. How do SASRec and BERT4Rec models work under the hood?\n", |
11 | 11 | "\n", |
12 | | - "Transformer models came to recommendation systems from NLP, where they are proved to have a significant impact. As transformers were applied to sequential data it is common to use them for session-based recommendations, where interactions are ordered by the date of their occurrence. In this tutorial focus is on SASRec and BERT4Rec - models which are considered as a common starting point for transformer application in RecSys. \n", |
| 12 | + "Transformer models came to recommendation systems from NLP, where they have proved to have a significant impact. Because transformers operate on sequential data, they are commonly used in recommender systems where interactions are ordered by the date of their occurrence. This tutorial focuses on SASRec and BERT4Rec, models that are considered a common starting point for transformer applications in RecSys. \n", |
13 | 13 | "\n", |
14 | 14 | "### Why transformers from RecTools?\n", |
15 | 15 | "\n", |
|
440 | 440 | ], |
441 | 441 | "source": [ |
442 | 442 | "# Prepare test user\n", |
443 | | - "test_user = [176549] \n", |
444 | | - "print(interactions[interactions[\"user_id\"] == test_user[0]].shape)\n", |
445 | | - "interactions[interactions[\"user_id\"] == test_user[0]].head(2)" |
446 | | - ] |
447 | | - }, |
448 | | - { |
449 | | - "cell_type": "code", |
450 | | - "execution_count": 8, |
451 | | - "metadata": {}, |
452 | | - "outputs": [], |
453 | | - "source": [ |
454 | | - "# Prepare test item\n", |
455 | | - "test_item = [13865]" |
| 443 | + "test_user = 176549\n", |
| 444 | + "print(interactions[interactions[\"user_id\"] == test_user].shape)\n", |
| 445 | + "interactions[interactions[\"user_id\"] == test_user].head(2)" |
456 | 446 | ] |
457 | 447 | }, |
458 | 448 | { |
|
469 | 459 | "metadata": {}, |
470 | 460 | "source": [ |
471 | 461 | "## SASRec\n", |
472 | | - "SASRec is a transformer-based sequential model with <b>unidirectional</b> attention mechanism and <b>\"Shifted Sequence\"</b> training objective. Resulting user sequence latent representation is used to predict all items in user sequence at each sequence position where each item prediction is based only on previous item information.\n" |
| 462 | + "SASRec is a transformer-based sequential model with a <b>unidirectional</b> attention mechanism and a <b>\"Shifted Sequence\"</b> training objective. The resulting latent representation of the user sequence is used to predict every item in the user sequence, position by position, with each prediction based only on the preceding items.\n" |
473 | 463 | ] |
474 | 464 | }, |
475 | 465 | { |
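The SASRec description in the hunk above can be illustrated with a minimal sketch in plain Python. It is not the RecTools implementation: the helper names, the padding id `PAD = 0`, and left padding are assumptions made for illustration. It shows how a "Shifted Sequence" objective pairs each input position with the next item, and what a unidirectional (causal) attention mask looks like.

```python
from typing import List, Tuple

PAD = 0  # hypothetical padding id, assumed for this sketch


def shifted_sequence_pairs(items: List[int], max_len: int) -> Tuple[List[int], List[int]]:
    """Build (inputs, targets) where targets are inputs shifted one step ahead."""
    # Keep only the most recent max_len + 1 items so inputs fit into max_len.
    seq = items[-(max_len + 1):]
    inputs, targets = seq[:-1], seq[1:]
    # Left-pad both sequences to a fixed length (assumed padding side).
    pad = max_len - len(inputs)
    return [PAD] * pad + inputs, [PAD] * pad + targets


def causal_mask(size: int) -> List[List[bool]]:
    """mask[i][j] is True when position i is allowed to attend to position j."""
    return [[j <= i for j in range(size)] for i in range(size)]


x, y = shifted_sequence_pairs([11, 12, 13, 14], max_len=5)
mask = causal_mask(3)
```

Here `x` holds the items the model sees and `y` holds the item it should predict at each position, so position `i` is trained to predict the item at position `i + 1` while the mask hides everything after position `i`.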
|
886 | 876 | "source": [ |
887 | 877 | "%%time\n", |
888 | 878 | "recos = sasrec.recommend(\n", |
889 | | - " users=test_user, \n", |
| 879 | + " users=[test_user], \n", |
890 | 880 | " dataset=dataset,\n", |
891 | 881 | " k=3,\n", |
892 | 882 | " filter_viewed=True,\n", |
|
1027 | 1017 | "cell_type": "markdown", |
1028 | 1018 | "metadata": {}, |
1029 | 1019 | "source": [ |
1030 | | - "* Specify minimum number of user interactions in dataset that is required to include user to model training with `train_min_user_interaction`\n", |
| 1020 | + "* Specify the minimum number of interactions a user must have in the dataset to be included in model training with `train_min_user_interactions`\n", |
1031 | 1021 | "* Specify whether positional encoding should be used with `use_pos_emb`\n", |
1032 | 1022 | "* Specify whether key_padding_mask in multi-head attention should be used with `use_key_padding_mask`. BERT4Rec has it set to ``True`` by default. `SASRec` has it set to ``False`` by default because of explicit zero multiplication of padding embeddings inside transformer layers that we inherited from the original implementation.\n", |
1033 | 1023 | "\n", |
|
1418 | 1408 | } |
1419 | 1409 | ], |
1420 | 1410 | "source": [ |
1421 | | - "plot_metrics = [{\"model\": cv_results_dict[\"model\"], \"MAP@10\": cv_results_dict[\"MAP@10\"], \"Serendipity@10\": cv_results_dict[\"Serendipity@10\"]} \n", |
1422 | | - " for cv_results_dict in cv_results[\"metrics\"]]\n", |
1423 | | - "\n", |
1424 | | - "models_metadata = [{\"model\": model_name, \n", |
1425 | | - " \"item_net_block_types\": \",\".join(block for block in [\"Id\", \"Cat\"] \n", |
1426 | | - " if re.search(block, str(model.get_params()[\"item_net_block_types\"]))),\n", |
1427 | | - " } \n", |
1428 | | - " for model_name, model in models.items() if model_name not in [\"popular\", \"ease\", \"bert4rec_softmax_ids_and_cat\"]]\n", |
1429 | | - "\n", |
| 1411 | + "models_metrics = pivot_results[[\"model\", \"MAP@10\", \"Serendipity@10\"]]\n", |
1430 | 1412 | "\n", |
| 1413 | + "models_to_skip_meta = [\"popular\", \"ease\", \"bert4rec_softmax_ids_and_cat\"]\n", |
| 1414 | + "models_metadata = [\n", |
| 1415 | + " {\n", |
| 1416 | + " \"model\": model_name, \n", |
| 1417 | + " \"item_net_block_types\": \",\".join(\n", |
| 1418 | + " block for block in [\"Id\", \"Cat\"] \n", |
| 1419 | + " if re.search(block, str(model.get_params()[\"item_net_block_types\"]))\n", |
| 1420 | + " ),\n", |
| 1421 | + " } \n", |
| 1422 | + " for model_name, model in models.items() if model_name not in models_to_skip_meta\n", |
| 1423 | + "]\n", |
1431 | 1424 | "\n", |
1432 | 1425 | "app = MetricsApp.construct(\n", |
1433 | | - " models_metrics=pd.DataFrame(plot_metrics),\n", |
| 1426 | + " models_metrics=models_metrics,\n", |
1434 | 1427 | " models_metadata=pd.DataFrame(models_metadata),\n", |
1435 | | - " scatter_kwargs={\"color_discrete_sequence\": px.colors.qualitative.Dark24,\n", |
1436 | | - " \"symbol_sequence\": ['circle', 'square', 'diamond', 'cross', 'x', 'star', 'pentagon'],}\n", |
| 1428 | + " scatter_kwargs={\n", |
| 1429 | + " \"color_discrete_sequence\": px.colors.qualitative.Dark24,\n", |
| 1430 | + " \"symbol_sequence\": ['circle', 'square', 'diamond', 'cross', 'x', 'star', 'pentagon'],\n", |
| 1431 | + " }\n", |
1437 | 1432 | ")\n", |
1438 | 1433 | "fig = app.fig\n", |
1439 | | - "fig.update_layout(title=\"Model CV metrics\",\n", |
1440 | | - " font={\"size\": 15})\n", |
| 1434 | + "fig.update_layout(title=\"Model CV metrics\", font={\"size\": 15})\n", |
1441 | 1435 | "fig.update_traces(marker={'size': 9})\n", |
1442 | 1436 | "fig.show(\"png\")" |
1443 | 1437 | ] |
|
1499 | 1493 | } |
1500 | 1494 | ], |
1501 | 1495 | "source": [ |
1502 | | - "items[items['item_id'] == test_item[0]][\"title\"]" |
| 1496 | + "# Prepare test item\n", |
| 1497 | + "test_item = 13865\n", |
| 1498 | + "items.loc[items['item_id'] == test_item, \"title\"]" |
1503 | 1499 | ] |
1504 | 1500 | }, |
1505 | 1501 | { |
|
1587 | 1583 | "source": [ |
1588 | 1584 | "%%time\n", |
1589 | 1585 | "recos = sasrec.recommend_to_items(\n", |
1590 | | - " target_items=test_item, \n", |
| 1586 | + " target_items=[test_item], \n", |
1591 | 1587 | " dataset=dataset,\n", |
1592 | 1588 | " k=3,\n", |
1593 | 1589 | " filter_itself=True,\n", |
|
1601 | 1597 | "metadata": {}, |
1602 | 1598 | "source": [ |
1603 | 1599 | "## Inference tricks (model known items and inference for cold users)\n", |
1604 | | - "It may happen that SASRec or BERT4Rec filters out users with less than `train_min_user_interaction` interactions during the train stage. However, it is still possible to make recommendations for those users if they have at least one interaction in history with an item that was present at training.\n", |
| 1600 | + "It may happen that SASRec or BERT4Rec filters out users with fewer than `train_min_user_interactions` interactions during the training stage. However, it is still possible to make recommendations for those users if their history contains at least one interaction with an item that was present at training.\n", |
1605 | 1601 | "\n", |
1606 | 1602 | "As an example consider user 324373, for whom there is only one interaction in the dataset." |
1607 | 1603 | ] |
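The distinction drawn in the hunk above, between users who are used for training and users who can still receive recommendations, can be sketched in plain Python. This is an illustrative approximation under stated assumptions, not the RecTools code path: `split_users`, the toy interaction pairs, and the threshold value are all hypothetical.

```python
from collections import defaultdict
from typing import Hashable, Iterable, Set, Tuple


def split_users(
    interactions: Iterable[Tuple[Hashable, Hashable]],
    min_interactions: int,
) -> Tuple[Set[Hashable], Set[Hashable]]:
    """Return (users kept for training, users who can get recommendations)."""
    history = defaultdict(list)
    for user, item in interactions:
        history[user].append(item)

    # Users with too few interactions are dropped from training.
    train_users = {u for u, its in history.items() if len(its) >= min_interactions}
    # Only items seen in the remaining training data are known to the model.
    known_items = {i for u in train_users for i in history[u]}
    # Inference needs at least one known item in the user's history.
    can_recommend = {u for u, its in history.items() if any(i in known_items for i in its)}
    return train_users, can_recommend


# Toy data: user 2 has a single interaction with a known item,
# user 3 has a single interaction with an item unknown to the model.
toy = [(1, "a"), (1, "b"), (2, "a"), (3, "z")]
train_users, can_recommend = split_users(toy, min_interactions=2)
```

In this toy setup user 2 is excluded from training but remains eligible for recommendations, while user 3 is not eligible because the only item in that history never appeared in training.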
|
1669 | 1665 | ], |
1670 | 1666 | "source": [ |
1671 | 1667 | "# Prepare test user with 1 interaction\n", |
1672 | | - "test_user_recs = [324373] \n", |
1673 | | - "print(interactions[interactions[\"user_id\"] == test_user_recs[0]].shape)\n", |
1674 | | - "interactions[interactions[\"user_id\"] == test_user_recs[0]]" |
| 1668 | + "test_user_one = 324373\n", |
| 1669 | + "print(interactions[interactions[\"user_id\"] == test_user_one].shape)\n", |
| 1670 | + "interactions[interactions[\"user_id\"] == test_user_one]" |
1675 | 1671 | ] |
1676 | 1672 | }, |
1677 | 1673 | { |
|
1759 | 1755 | "source": [ |
1760 | 1756 | "%%time\n", |
1761 | 1757 | "recos = sasrec.recommend(\n", |
1762 | | - " users=test_user_recs, \n", |
| 1758 | + " users=[test_user_one], \n", |
1763 | 1759 | " dataset=dataset,\n", |
1764 | 1760 | " k=3,\n", |
1765 | 1761 | " filter_viewed=True,\n", |
|
1838 | 1834 | ], |
1839 | 1835 | "source": [ |
1840 | 1836 | "# Prepare test user with items unknown by the model\n", |
1841 | | - "test_user_no_recs = [14630] \n", |
1842 | | - "print(interactions[interactions[\"user_id\"] == test_user_no_recs[0]].shape)\n", |
1843 | | - "interactions[interactions[\"user_id\"] == test_user_no_recs[0]].head(2)" |
| 1837 | + "test_user_no_recs = 14630\n", |
| 1838 | + "print(interactions[interactions[\"user_id\"] == test_user_no_recs].shape)\n", |
| 1839 | + "interactions[interactions[\"user_id\"] == test_user_no_recs].head(2)" |
1844 | 1840 | ] |
1845 | 1841 | }, |
1846 | 1842 | { |
|
1915 | 1911 | "source": [ |
1916 | 1912 | "%%time\n", |
1917 | 1913 | "recos = sasrec.recommend(\n", |
1918 | | - " users=test_user_no_recs, \n", |
| 1914 | + " users=[test_user_no_recs], \n", |
1919 | 1915 | " dataset=dataset,\n", |
1920 | 1916 | " k=3,\n", |
1921 | 1917 | " filter_viewed=True,\n", |
1922 | | - " on_unsupported_targets=\"warn\"\n", |
| 1918 | + " on_unsupported_targets=\"ignore\" # prevent raising an error\n", |
1923 | 1919 | ")\n", |
1924 | 1920 | "recos.merge(items[[\"item_id\", \"title_orig\"]], on=\"item_id\").sort_values([\"user_id\", \"rank\"])" |
1925 | 1921 | ] |
|
2155 | 2151 | ], |
2156 | 2152 | "metadata": { |
2157 | 2153 | "kernelspec": { |
2158 | | - "display_name": "venv", |
| 2154 | + "display_name": "Python 3 (ipykernel)", |
2159 | 2155 | "language": "python", |
2160 | 2156 | "name": "python3" |
2161 | 2157 | }, |
|
2169 | 2165 | "name": "python", |
2170 | 2166 | "nbconvert_exporter": "python", |
2171 | 2167 | "pygments_lexer": "ipython3", |
2172 | | - "version": "3.9.12" |
| 2168 | + "version": "3.7.10" |
2173 | 2169 | } |
2174 | 2170 | }, |
2175 | 2171 | "nbformat": 4, |
|