diff --git a/HW-3.1.ipynb b/HW-3.1.ipynb
new file mode 100644
index 00000000..c05a3b71
--- /dev/null
+++ b/HW-3.1.ipynb
@@ -0,0 +1,4027 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "398a86d9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pprint import pprint\n",
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "import sys\n",
+ "sys.path.append('../')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8dbe6bf0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import plotly.express as px\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import scipy as sp\n",
+ "import requests\n",
+ "from tqdm.auto import tqdm\n",
+ "from scipy.stats import mode\n",
+ "from implicit.nearest_neighbours import CosineRecommender, TFIDFRecommender, BM25Recommender\n",
+ "from rectools import Columns\n",
+ "from rectools.model_selection import TimeRangeSplitter\n",
+ "from rectools.metrics import Precision, Recall, MAP, MeanInvUserFreq, Serendipity, calc_metrics\n",
+ "from rectools.dataset.interactions import Interactions\n",
+ "\n",
+ "from service.utils.user_knn import UserKnn"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1baa79f",
+ "metadata": {},
+ "source": [
+ "# Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "f2a9e540",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((5476251, 5), (840197, 5), (15963, 14))"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions = pd.read_csv('../data/kion_train/interactions.csv')\n",
+ "users = pd.read_csv('../data/kion_train/users.csv')\n",
+ "items = pd.read_csv('../data/kion_train/items.csv')\n",
+ "\n",
+ "interactions.shape, users.shape, items.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "456d25f4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "interactions.rename(\n",
+ " columns={\n",
+ " 'last_watch_dt': Columns.Datetime,\n",
+ " 'total_dur': Columns.Weight\n",
+ " }, \n",
+ " inplace=True) \n",
+ "\n",
+ "interactions[Columns.Datetime] = pd.to_datetime(interactions[Columns.Datetime])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f7b9b0c",
+ "metadata": {},
+ "source": [
+ "## Interactions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "7c9c0c94",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " datetime | \n",
+ " weight | \n",
+ " watched_pct | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 176549 | \n",
+ " 9506 | \n",
+ " 2021-05-11 | \n",
+ " 4250 | \n",
+ " 72.0 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 699317 | \n",
+ " 1659 | \n",
+ " 2021-05-29 | \n",
+ " 8317 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 656683 | \n",
+ " 7107 | \n",
+ " 2021-05-09 | \n",
+ " 10 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 864613 | \n",
+ " 7638 | \n",
+ " 2021-07-05 | \n",
+ " 14483 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 964868 | \n",
+ " 9506 | \n",
+ " 2021-04-30 | \n",
+ " 6725 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476246 | \n",
+ " 648596 | \n",
+ " 12225 | \n",
+ " 2021-08-13 | \n",
+ " 76 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476247 | \n",
+ " 546862 | \n",
+ " 9673 | \n",
+ " 2021-04-13 | \n",
+ " 2308 | \n",
+ " 49.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476248 | \n",
+ " 697262 | \n",
+ " 15297 | \n",
+ " 2021-08-20 | \n",
+ " 18307 | \n",
+ " 63.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476249 | \n",
+ " 384202 | \n",
+ " 16197 | \n",
+ " 2021-04-19 | \n",
+ " 6203 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476250 | \n",
+ " 319709 | \n",
+ " 4436 | \n",
+ " 2021-08-15 | \n",
+ " 3921 | \n",
+ " 45.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id datetime weight watched_pct\n",
+ "0 176549 9506 2021-05-11 4250 72.0\n",
+ "1 699317 1659 2021-05-29 8317 100.0\n",
+ "2 656683 7107 2021-05-09 10 0.0\n",
+ "3 864613 7638 2021-07-05 14483 100.0\n",
+ "4 964868 9506 2021-04-30 6725 100.0\n",
+ "5476246 648596 12225 2021-08-13 76 0.0\n",
+ "5476247 546862 9673 2021-04-13 2308 49.0\n",
+ "5476248 697262 15297 2021-08-20 18307 63.0\n",
+ "5476249 384202 16197 2021-04-19 6203 100.0\n",
+ "5476250 319709 4436 2021-08-15 3921 45.0"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.concat([interactions.head(), interactions.tail()])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "c5c3ce6c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Interactions dataframe shape: (5476251, 5)\n",
+ "Unique users in interactions: 962179\n",
+ "Unique items in interactions: 15706\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Interactions dataframe shape: {interactions.shape}\")\n",
+ "print(f\"Unique users in interactions: {interactions[Columns.User].nunique()}\")\n",
+ "print(f\"Unique items in interactions: {interactions[Columns.Item].nunique()}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "0214a978",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "min date in interactions: 2021-03-13 00:00:00\n",
+ "max date in interactions: 2021-08-22 00:00:00\n"
+ ]
+ }
+ ],
+ "source": [
+ "max_date = interactions[Columns.Datetime].max()\n",
+ "min_date = interactions[Columns.Datetime].min()\n",
+ "\n",
+ "print(f\"min date in interactions: {min_date}\")\n",
+ "print(f\"max date in interactions: {max_date}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "7829e796",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "RangeIndex: 5476251 entries, 0 to 5476250\n",
+ "Data columns (total 5 columns):\n",
+ " # Column Dtype \n",
+ "--- ------ ----- \n",
+ " 0 user_id int64 \n",
+ " 1 item_id int64 \n",
+ " 2 datetime datetime64[ns]\n",
+ " 3 weight int64 \n",
+ " 4 watched_pct float64 \n",
+ "dtypes: datetime64[ns](1), float64(1), int64(3)\n",
+ "memory usage: 208.9 MB\n"
+ ]
+ }
+ ],
+ "source": [
+ "interactions.info()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "57cddf34",
+ "metadata": {},
+ "source": [
+ "## Users"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "de5dea16",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " age | \n",
+ " income | \n",
+ " sex | \n",
+ " kids_flg | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 973171 | \n",
+ " age_25_34 | \n",
+ " income_60_90 | \n",
+ " М | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 962099 | \n",
+ " age_18_24 | \n",
+ " income_20_40 | \n",
+ " М | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1047345 | \n",
+ " age_45_54 | \n",
+ " income_40_60 | \n",
+ " Ж | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 721985 | \n",
+ " age_45_54 | \n",
+ " income_20_40 | \n",
+ " Ж | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 704055 | \n",
+ " age_35_44 | \n",
+ " income_60_90 | \n",
+ " Ж | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 840192 | \n",
+ " 339025 | \n",
+ " age_65_inf | \n",
+ " income_0_20 | \n",
+ " Ж | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 840193 | \n",
+ " 983617 | \n",
+ " age_18_24 | \n",
+ " income_20_40 | \n",
+ " Ж | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " | 840194 | \n",
+ " 251008 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 840195 | \n",
+ " 590706 | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " Ж | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 840196 | \n",
+ " 166555 | \n",
+ " age_65_inf | \n",
+ " income_20_40 | \n",
+ " Ж | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id age income sex kids_flg\n",
+ "0 973171 age_25_34 income_60_90 М 1\n",
+ "1 962099 age_18_24 income_20_40 М 0\n",
+ "2 1047345 age_45_54 income_40_60 Ж 0\n",
+ "3 721985 age_45_54 income_20_40 Ж 0\n",
+ "4 704055 age_35_44 income_60_90 Ж 0\n",
+ "840192 339025 age_65_inf income_0_20 Ж 0\n",
+ "840193 983617 age_18_24 income_20_40 Ж 1\n",
+ "840194 251008 NaN NaN NaN 0\n",
+ "840195 590706 NaN NaN Ж 0\n",
+ "840196 166555 age_65_inf income_20_40 Ж 0"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.concat([users.head(), users.tail()])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "e4e6d2f5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Users dataframe shape (840197, 5)\n",
+ "Unique users: 840197\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Users dataframe shape {users.shape}\")\n",
+ "print(f\"Unique users: {users['user_id'].nunique()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "98b4ff6c",
+ "metadata": {},
+ "source": [
+ "## Items"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "19b43ff0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " content_type | \n",
+ " title | \n",
+ " title_orig | \n",
+ " release_year | \n",
+ " genres | \n",
+ " countries | \n",
+ " for_kids | \n",
+ " age_rating | \n",
+ " studios | \n",
+ " directors | \n",
+ " actors | \n",
+ " description | \n",
+ " keywords | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 10711 | \n",
+ " film | \n",
+ " Поговори с ней | \n",
+ " Hable con ella | \n",
+ " 2002.0 | \n",
+ " драмы, зарубежные, детективы, мелодрамы | \n",
+ " Испания | \n",
+ " NaN | \n",
+ " 16.0 | \n",
+ " NaN | \n",
+ " Педро Альмодовар | \n",
+ " Адольфо Фернандес, Ана Фернандес, Дарио Гранди... | \n",
+ " Мелодрама легендарного Педро Альмодовара «Пого... | \n",
+ " Поговори, ней, 2002, Испания, друзья, любовь, ... | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2508 | \n",
+ " film | \n",
+ " Голые перцы | \n",
+ " Search Party | \n",
+ " 2014.0 | \n",
+ " зарубежные, приключения, комедии | \n",
+ " США | \n",
+ " NaN | \n",
+ " 16.0 | \n",
+ " NaN | \n",
+ " Скот Армстронг | \n",
+ " Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... | \n",
+ " Уморительная современная комедия на популярную... | \n",
+ " Голые, перцы, 2014, США, друзья, свадьбы, прео... | \n",
+ "
\n",
+ " \n",
+ " | 15961 | \n",
+ " 4538 | \n",
+ " series | \n",
+ " Среди камней | \n",
+ " Darklands | \n",
+ " 2019.0 | \n",
+ " драмы, спорт, криминал | \n",
+ " Россия | \n",
+ " 0.0 | \n",
+ " 18.0 | \n",
+ " NaN | \n",
+ " Марк О’Коннор, Конор МакМахон | \n",
+ " Дэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд... | \n",
+ " Семнадцатилетний Дэмиен мечтает вырваться за п... | \n",
+ " Среди, камней, 2019, Россия | \n",
+ "
\n",
+ " \n",
+ " | 15962 | \n",
+ " 3206 | \n",
+ " series | \n",
+ " Гоша | \n",
+ " NaN | \n",
+ " 2019.0 | \n",
+ " комедии | \n",
+ " Россия | \n",
+ " 0.0 | \n",
+ " 16.0 | \n",
+ " NaN | \n",
+ " Михаил Миронов | \n",
+ " Мкртыч Арзуманян, Виктория Рунцова | \n",
+ " Добродушный Гоша не может выйти из дома, чтобы... | \n",
+ " Гоша, 2019, Россия | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id content_type title title_orig release_year \\\n",
+ "0 10711 film Поговори с ней Hable con ella 2002.0 \n",
+ "1 2508 film Голые перцы Search Party 2014.0 \n",
+ "15961 4538 series Среди камней Darklands 2019.0 \n",
+ "15962 3206 series Гоша NaN 2019.0 \n",
+ "\n",
+ " genres countries for_kids \\\n",
+ "0 драмы, зарубежные, детективы, мелодрамы Испания NaN \n",
+ "1 зарубежные, приключения, комедии США NaN \n",
+ "15961 драмы, спорт, криминал Россия 0.0 \n",
+ "15962 комедии Россия 0.0 \n",
+ "\n",
+ " age_rating studios directors \\\n",
+ "0 16.0 NaN Педро Альмодовар \n",
+ "1 16.0 NaN Скот Армстронг \n",
+ "15961 18.0 NaN Марк О’Коннор, Конор МакМахон \n",
+ "15962 16.0 NaN Михаил Миронов \n",
+ "\n",
+ " actors \\\n",
+ "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n",
+ "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n",
+ "15961 Дэйн Уайт О’Хара, Томас Кэйн-Бирн, Джудит Родд... \n",
+ "15962 Мкртыч Арзуманян, Виктория Рунцова \n",
+ "\n",
+ " description \\\n",
+ "0 Мелодрама легендарного Педро Альмодовара «Пого... \n",
+ "1 Уморительная современная комедия на популярную... \n",
+ "15961 Семнадцатилетний Дэмиен мечтает вырваться за п... \n",
+ "15962 Добродушный Гоша не может выйти из дома, чтобы... \n",
+ "\n",
+ " keywords \n",
+ "0 Поговори, ней, 2002, Испания, друзья, любовь, ... \n",
+ "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... \n",
+ "15961 Среди, камней, 2019, Россия \n",
+ "15962 Гоша, 2019, Россия "
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.concat([items.head(2), items.tail(2)])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "8c8fb319",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Items dataframe shape (15963, 14)\n",
+ "Unique item_id: 15963\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Items dataframe shape {items.shape}\")\n",
+ "print(f\"Unique item_id: {items['item_id'].nunique()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2b35b460",
+ "metadata": {},
+ "source": [
+ "# userkNN model CV"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "f60e6ecb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "alignmentgroup": "True",
+ "hovertemplate": "variable=user_id
datetime=%{x}
value=%{y}",
+ "legendgroup": "user_id",
+ "marker": {
+ "color": "#636efa",
+ "pattern": {
+ "shape": ""
+ }
+ },
+ "name": "user_id",
+ "offsetgroup": "user_id",
+ "orientation": "v",
+ "showlegend": true,
+ "textposition": "auto",
+ "type": "bar",
+ "x": [
+ "2021-03-13T00:00:00",
+ "2021-03-14T00:00:00",
+ "2021-03-15T00:00:00",
+ "2021-03-16T00:00:00",
+ "2021-03-17T00:00:00",
+ "2021-03-18T00:00:00",
+ "2021-03-19T00:00:00",
+ "2021-03-20T00:00:00",
+ "2021-03-21T00:00:00",
+ "2021-03-22T00:00:00",
+ "2021-03-23T00:00:00",
+ "2021-03-24T00:00:00",
+ "2021-03-25T00:00:00",
+ "2021-03-26T00:00:00",
+ "2021-03-27T00:00:00",
+ "2021-03-28T00:00:00",
+ "2021-03-29T00:00:00",
+ "2021-03-30T00:00:00",
+ "2021-03-31T00:00:00",
+ "2021-04-01T00:00:00",
+ "2021-04-02T00:00:00",
+ "2021-04-03T00:00:00",
+ "2021-04-04T00:00:00",
+ "2021-04-05T00:00:00",
+ "2021-04-06T00:00:00",
+ "2021-04-07T00:00:00",
+ "2021-04-08T00:00:00",
+ "2021-04-09T00:00:00",
+ "2021-04-10T00:00:00",
+ "2021-04-11T00:00:00",
+ "2021-04-12T00:00:00",
+ "2021-04-13T00:00:00",
+ "2021-04-14T00:00:00",
+ "2021-04-15T00:00:00",
+ "2021-04-16T00:00:00",
+ "2021-04-17T00:00:00",
+ "2021-04-18T00:00:00",
+ "2021-04-19T00:00:00",
+ "2021-04-20T00:00:00",
+ "2021-04-21T00:00:00",
+ "2021-04-22T00:00:00",
+ "2021-04-23T00:00:00",
+ "2021-04-24T00:00:00",
+ "2021-04-25T00:00:00",
+ "2021-04-26T00:00:00",
+ "2021-04-27T00:00:00",
+ "2021-04-28T00:00:00",
+ "2021-04-29T00:00:00",
+ "2021-04-30T00:00:00",
+ "2021-05-01T00:00:00",
+ "2021-05-02T00:00:00",
+ "2021-05-03T00:00:00",
+ "2021-05-04T00:00:00",
+ "2021-05-05T00:00:00",
+ "2021-05-06T00:00:00",
+ "2021-05-07T00:00:00",
+ "2021-05-08T00:00:00",
+ "2021-05-09T00:00:00",
+ "2021-05-10T00:00:00",
+ "2021-05-11T00:00:00",
+ "2021-05-12T00:00:00",
+ "2021-05-13T00:00:00",
+ "2021-05-14T00:00:00",
+ "2021-05-15T00:00:00",
+ "2021-05-16T00:00:00",
+ "2021-05-17T00:00:00",
+ "2021-05-18T00:00:00",
+ "2021-05-19T00:00:00",
+ "2021-05-20T00:00:00",
+ "2021-05-21T00:00:00",
+ "2021-05-22T00:00:00",
+ "2021-05-23T00:00:00",
+ "2021-05-24T00:00:00",
+ "2021-05-25T00:00:00",
+ "2021-05-26T00:00:00",
+ "2021-05-27T00:00:00",
+ "2021-05-28T00:00:00",
+ "2021-05-29T00:00:00",
+ "2021-05-30T00:00:00",
+ "2021-05-31T00:00:00",
+ "2021-06-01T00:00:00",
+ "2021-06-02T00:00:00",
+ "2021-06-03T00:00:00",
+ "2021-06-04T00:00:00",
+ "2021-06-05T00:00:00",
+ "2021-06-06T00:00:00",
+ "2021-06-07T00:00:00",
+ "2021-06-08T00:00:00",
+ "2021-06-09T00:00:00",
+ "2021-06-10T00:00:00",
+ "2021-06-11T00:00:00",
+ "2021-06-12T00:00:00",
+ "2021-06-13T00:00:00",
+ "2021-06-14T00:00:00",
+ "2021-06-15T00:00:00",
+ "2021-06-16T00:00:00",
+ "2021-06-17T00:00:00",
+ "2021-06-18T00:00:00",
+ "2021-06-19T00:00:00",
+ "2021-06-20T00:00:00",
+ "2021-06-21T00:00:00",
+ "2021-06-22T00:00:00",
+ "2021-06-23T00:00:00",
+ "2021-06-24T00:00:00",
+ "2021-06-25T00:00:00",
+ "2021-06-26T00:00:00",
+ "2021-06-27T00:00:00",
+ "2021-06-28T00:00:00",
+ "2021-06-29T00:00:00",
+ "2021-06-30T00:00:00",
+ "2021-07-01T00:00:00",
+ "2021-07-02T00:00:00",
+ "2021-07-03T00:00:00",
+ "2021-07-04T00:00:00",
+ "2021-07-05T00:00:00",
+ "2021-07-06T00:00:00",
+ "2021-07-07T00:00:00",
+ "2021-07-08T00:00:00",
+ "2021-07-09T00:00:00",
+ "2021-07-10T00:00:00",
+ "2021-07-11T00:00:00",
+ "2021-07-12T00:00:00",
+ "2021-07-13T00:00:00",
+ "2021-07-14T00:00:00",
+ "2021-07-15T00:00:00",
+ "2021-07-16T00:00:00",
+ "2021-07-17T00:00:00",
+ "2021-07-18T00:00:00",
+ "2021-07-19T00:00:00",
+ "2021-07-20T00:00:00",
+ "2021-07-21T00:00:00",
+ "2021-07-22T00:00:00",
+ "2021-07-23T00:00:00",
+ "2021-07-24T00:00:00",
+ "2021-07-25T00:00:00",
+ "2021-07-26T00:00:00",
+ "2021-07-27T00:00:00",
+ "2021-07-28T00:00:00",
+ "2021-07-29T00:00:00",
+ "2021-07-30T00:00:00",
+ "2021-07-31T00:00:00",
+ "2021-08-01T00:00:00",
+ "2021-08-02T00:00:00",
+ "2021-08-03T00:00:00",
+ "2021-08-04T00:00:00",
+ "2021-08-05T00:00:00",
+ "2021-08-06T00:00:00",
+ "2021-08-07T00:00:00",
+ "2021-08-08T00:00:00",
+ "2021-08-09T00:00:00",
+ "2021-08-10T00:00:00",
+ "2021-08-11T00:00:00",
+ "2021-08-12T00:00:00",
+ "2021-08-13T00:00:00",
+ "2021-08-14T00:00:00",
+ "2021-08-15T00:00:00",
+ "2021-08-16T00:00:00",
+ "2021-08-17T00:00:00",
+ "2021-08-18T00:00:00",
+ "2021-08-19T00:00:00",
+ "2021-08-20T00:00:00",
+ "2021-08-21T00:00:00",
+ "2021-08-22T00:00:00"
+ ],
+ "xaxis": "x",
+ "y": [
+ 16104,
+ 15606,
+ 12363,
+ 12643,
+ 12753,
+ 12788,
+ 13657,
+ 15346,
+ 15560,
+ 12752,
+ 13147,
+ 13435,
+ 12698,
+ 13909,
+ 15657,
+ 16112,
+ 12783,
+ 13101,
+ 13460,
+ 12966,
+ 14084,
+ 15431,
+ 15346,
+ 12642,
+ 12528,
+ 13129,
+ 13827,
+ 14416,
+ 15937,
+ 16046,
+ 12835,
+ 12322,
+ 12451,
+ 12275,
+ 13342,
+ 15464,
+ 16275,
+ 14286,
+ 20420,
+ 23200,
+ 21274,
+ 22127,
+ 26161,
+ 28964,
+ 21625,
+ 22590,
+ 21406,
+ 19987,
+ 21406,
+ 23479,
+ 24767,
+ 26267,
+ 25983,
+ 23941,
+ 23510,
+ 23201,
+ 27550,
+ 25986,
+ 27242,
+ 20957,
+ 20578,
+ 20729,
+ 21152,
+ 24530,
+ 24914,
+ 20960,
+ 20574,
+ 21561,
+ 22712,
+ 25697,
+ 27895,
+ 29978,
+ 24317,
+ 23667,
+ 22529,
+ 23881,
+ 24131,
+ 29035,
+ 31308,
+ 26821,
+ 26587,
+ 27577,
+ 28683,
+ 33150,
+ 34795,
+ 37096,
+ 31402,
+ 31107,
+ 32896,
+ 38964,
+ 37935,
+ 38619,
+ 42125,
+ 38973,
+ 35993,
+ 57686,
+ 41440,
+ 42174,
+ 43679,
+ 47989,
+ 39127,
+ 39693,
+ 41688,
+ 38394,
+ 41428,
+ 45898,
+ 48903,
+ 43301,
+ 43887,
+ 67749,
+ 53900,
+ 46642,
+ 48832,
+ 52812,
+ 43375,
+ 41380,
+ 41163,
+ 41592,
+ 40955,
+ 44798,
+ 46250,
+ 42487,
+ 43764,
+ 43128,
+ 43010,
+ 44878,
+ 49714,
+ 54139,
+ 45541,
+ 44431,
+ 44422,
+ 46313,
+ 46911,
+ 50317,
+ 54378,
+ 48531,
+ 49324,
+ 50267,
+ 50585,
+ 53121,
+ 59499,
+ 62128,
+ 53495,
+ 52181,
+ 51911,
+ 51047,
+ 53745,
+ 59316,
+ 61454,
+ 52794,
+ 53712,
+ 55617,
+ 56497,
+ 55843,
+ 61644,
+ 66546,
+ 54546,
+ 54311,
+ 56789,
+ 58640,
+ 60145,
+ 68834,
+ 71171
+ ],
+ "yaxis": "y"
+ }
+ ],
+ "layout": {
+ "barmode": "relative",
+ "legend": {
+ "title": {
+ "text": "variable"
+ },
+ "tracegroupgap": 0
+ },
+ "margin": {
+ "t": 60
+ },
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "xaxis": {
+ "anchor": "y",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "datetime"
+ }
+ },
+ "yaxis": {
+ "anchor": "x",
+ "domain": [
+ 0,
+ 1
+ ],
+ "title": {
+ "text": "value"
+ }
+ }
+ }
+ },
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig = px.bar(interactions.groupby(Columns.Datetime)[Columns.User].agg('count'))\n",
+ "fig.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "43f216d0",
+ "metadata": {},
+ "source": [
+ "Из графика видны **недельные тенденции** просмотров, поэтому fold-ы следовало бы разделять по 7 дней. Однако на семинаре дали \"намек\", что private dataset охватывает меньше 7 дней, поэтому фолды будут разбиваться на **5 и 7 дней**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "07fbdb30",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "6"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.to_datetime('23-05-2021', format='%d-%m-%Y').weekday()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2ff625b2",
+ "metadata": {},
+ "source": [
+ "### train test split"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "759ba346",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_data_range(\n",
+ " last_date: pd.Timestamp, \n",
+ " n_folds: int = 7, \n",
+ " unit: str = \"W\", \n",
+ " n_units: int = 1, \n",
+ " show: bool = True,\n",
+ "):\n",
+ " periods = n_folds + 1\n",
+ " freq = f\"{n_units}{unit}\"\n",
+ " \n",
+ " start_date = last_date - pd.Timedelta(n_folds * n_units + n_units, unit=unit) \n",
+ " \n",
+ " date_range = pd.date_range(start=start_date, periods=periods, freq=freq, tz=last_date.tz)\n",
+ " \n",
+ " if show:\n",
+ " print(\n",
+ " f\"start_date: {start_date}\\n\"\n",
+ " f\"last_date: {last_date}\\n\"\n",
+ " f\"periods: {periods}\\n\"\n",
+ " f\"freq: {freq}\\n\"\n",
+ " f\"Test fold borders: {date_range.values.astype('datetime64[D]')}\\n\"\n",
+ " )\n",
+ " \n",
+ " return date_range"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "38bfd397",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "CONFIG_CV = {\n",
+ " \"cv_v1\": {\n",
+ " \"n_folds\": 7,\n",
+ " \"unit\": \"W\",\n",
+ " \"n_units\": 1,\n",
+ " },\n",
+ " \"cv_v2\": {\n",
+ " \"n_folds\": 7,\n",
+ " \"unit\": \"D\",\n",
+ " \"n_units\": 5,\n",
+ " }, \n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "f518e089",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Timestamp('2021-08-22 00:00:00')"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "last_date = interactions[Columns.Datetime].max().normalize()\n",
+ "last_date"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "1fd68b9b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "***Folds v1***\n",
+ "start_date: 2021-07-13 00:00:00\n",
+ "last_date: 2021-08-22 00:00:00\n",
+ "periods: 8\n",
+ "freq: 5D\n",
+ "Test fold borders: ['2021-07-13' '2021-07-18' '2021-07-23' '2021-07-28' '2021-08-02'\n",
+ " '2021-08-07' '2021-08-12' '2021-08-17']\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"***Folds v1***\")\n",
+ "# NOTE: despite the \"v1\" label, the parameters come from the \"cv_v2\" entry of\n",
+ "# CONFIG_CV; its keys (n_folds, unit, n_units) match the keyword arguments\n",
+ "# of create_data_range, so the entry is unpacked directly.\n",
+ "date_range_v1 = create_data_range(last_date, **CONFIG_CV[\"cv_v2\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "efc59555",
+ "metadata": {},
+ "source": [
+ "**генерируем фолды** "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "9fae43f6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Real number of folds: 7\n"
+ ]
+ }
+ ],
+ "source": [
+ "cv_v1 = TimeRangeSplitter(\n",
+ " date_range=date_range_v1,\n",
+ " filter_already_seen=True,\n",
+ " filter_cold_items=True,\n",
+ " filter_cold_users=True,\n",
+ ")\n",
+ "print(f\"Real number of folds: {cv_v1.get_n_splits(Interactions(interactions))}\")\n",
+ "\n",
+ "CV = [cv_v1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e15a83a7",
+ "metadata": {},
+ "source": [
+ "**Формируем метрики**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "8f7742c6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "metrics = {\n",
+ " \"prec@10\": Precision(k=10),\n",
+ " \"recall@10\": Recall(k=10),\n",
+ " \"MAP@10\": MAP(k=10),\n",
+ " \"novelty\": MeanInvUserFreq(k=10),\n",
+ " \"serendipity\": Serendipity(k=10),\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "b21a1ecf",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'cosine_userknn_K30': ,\n",
+ " 'tfidf_userknn_K30': ,\n",
+ " 'bm25_userknn_K30': ,\n",
+ " 'cosine_userknn_K40': ,\n",
+ " 'tfidf_userknn_K40': ,\n",
+ " 'bm25_userknn_K40': }"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "K = [30, 40]\n",
+ "\n",
+ "# (name prefix, recommender class) pairs in the order the models are registered\n",
+ "recommender_classes = (\n",
+ "    (\"cosine\", CosineRecommender),\n",
+ "    (\"tfidf\", TFIDFRecommender),\n",
+ "    (\"bm25\", BM25Recommender),\n",
+ ")\n",
+ "models = {}\n",
+ "\n",
+ "for k in K:\n",
+ "    for prefix, recommender_cls in recommender_classes:\n",
+ "        models[f\"{prefix}_userknn_K{k}\"] = recommender_cls(K=k)\n",
+ "\n",
+ "models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0103149a",
+ "metadata": {},
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "e78b8221",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_USERS = 50"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "50dcff0b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "\n",
+ "results = []\n",
+ "\n",
+ "for idx, cv in enumerate(CV):\n",
+ "    print(f\"\\n CV version {idx}\")\n",
+ "\n",
+ "    splits = cv.split(Interactions(interactions), collect_fold_stats=True)\n",
+ "    for i_fold, (train_ids, test_ids, fold_info) in enumerate(splits):\n",
+ "        print(f\"\\n==================== Fold {i_fold}\")\n",
+ "        pprint(fold_info)\n",
+ "\n",
+ "        df_train = interactions.iloc[train_ids].copy()\n",
+ "        df_test = interactions.iloc[test_ids][Columns.UserItem].copy()\n",
+ "        catalog = df_train[Columns.Item].unique()\n",
+ "\n",
+ "        for model_name, model in models.items():\n",
+ "            userknn_model = UserKnn(model=model, N_users=N_USERS, use_weight_idf=True)\n",
+ "            userknn_model.fit(df_train)\n",
+ "\n",
+ "            # models with 'bm25' in their name get the extra `bmp25` flag at predict time\n",
+ "            predict_kwargs = {\"bmp25\": True} if 'bm25' in model_name else {}\n",
+ "            recos = userknn_model.predict(df_test, **predict_kwargs)\n",
+ "\n",
+ "            metric_values = calc_metrics(\n",
+ "                metrics,\n",
+ "                reco=recos,\n",
+ "                interactions=df_test,\n",
+ "                prev_interactions=df_train,\n",
+ "                catalog=catalog,\n",
+ "            )\n",
+ "\n",
+ "            # one row per (fold, model): identifiers first, then the metric values\n",
+ "            results.append(\n",
+ "                {\"fold\": i_fold, \"model\": f\"{model_name}_cv-{idx}\", **metric_values}\n",
+ "            )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "708ec5c2",
+ "metadata": {},
+ "source": [
+ "Ячейка работала больше 10 часов. При перезапуске ноутбука она была случайно запущена повторно и остановлена, из-за чего завершилась с ошибкой; вывод с ошибкой удалён, чтобы ноутбук выглядел аккуратнее."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "id": "d7e2ffa7",
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " fold | \n",
+ " model | \n",
+ " prec@10 | \n",
+ " recall@10 | \n",
+ " MAP@10 | \n",
+ " novelty | \n",
+ " serendipity | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0 | \n",
+ " cosine_userknn_K30_cv-0 | \n",
+ " 0.003557 | \n",
+ " 0.021128 | \n",
+ " 0.003695 | \n",
+ " 8.331491 | \n",
+ " 0.000040 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 0 | \n",
+ " tfidf_userknn_K30_cv-0 | \n",
+ " 0.006439 | \n",
+ " 0.039102 | \n",
+ " 0.007335 | \n",
+ " 8.155051 | \n",
+ " 0.000048 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0 | \n",
+ " bm25_userknn_K30_cv-0 | \n",
+ " 0.002593 | \n",
+ " 0.013494 | \n",
+ " 0.002531 | \n",
+ " 9.398467 | \n",
+ " 0.000081 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0 | \n",
+ " cosine_userknn_K40_cv-0 | \n",
+ " 0.003282 | \n",
+ " 0.019323 | \n",
+ " 0.003401 | \n",
+ " 8.561523 | \n",
+ " 0.000043 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0 | \n",
+ " tfidf_userknn_K40_cv-0 | \n",
+ " 0.006178 | \n",
+ " 0.037458 | \n",
+ " 0.006957 | \n",
+ " 8.300404 | \n",
+ " 0.000052 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 0 | \n",
+ " bm25_userknn_K40_cv-0 | \n",
+ " 0.002241 | \n",
+ " 0.011255 | \n",
+ " 0.002210 | \n",
+ " 9.675533 | \n",
+ " 0.000081 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 1 | \n",
+ " cosine_userknn_K30_cv-0 | \n",
+ " 0.003505 | \n",
+ " 0.020002 | \n",
+ " 0.003580 | \n",
+ " 8.398248 | \n",
+ " 0.000046 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 1 | \n",
+ " tfidf_userknn_K30_cv-0 | \n",
+ " 0.006328 | \n",
+ " 0.036844 | \n",
+ " 0.007022 | \n",
+ " 8.240133 | \n",
+ " 0.000058 | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 1 | \n",
+ " bm25_userknn_K30_cv-0 | \n",
+ " 0.002722 | \n",
+ " 0.013856 | \n",
+ " 0.002658 | \n",
+ " 9.484692 | \n",
+ " 0.000088 | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 1 | \n",
+ " cosine_userknn_K40_cv-0 | \n",
+ " 0.003245 | \n",
+ " 0.018368 | \n",
+ " 0.003305 | \n",
+ " 8.626906 | \n",
+ " 0.000047 | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " 1 | \n",
+ " tfidf_userknn_K40_cv-0 | \n",
+ " 0.006150 | \n",
+ " 0.035964 | \n",
+ " 0.006916 | \n",
+ " 8.377988 | \n",
+ " 0.000061 | \n",
+ "
\n",
+ " \n",
+ " | 11 | \n",
+ " 1 | \n",
+ " bm25_userknn_K40_cv-0 | \n",
+ " 0.002406 | \n",
+ " 0.012067 | \n",
+ " 0.002393 | \n",
+ " 9.756458 | \n",
+ " 0.000086 | \n",
+ "
\n",
+ " \n",
+ " | 12 | \n",
+ " 2 | \n",
+ " cosine_userknn_K30_cv-0 | \n",
+ " 0.003261 | \n",
+ " 0.018498 | \n",
+ " 0.003295 | \n",
+ " 8.439263 | \n",
+ " 0.000047 | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " 2 | \n",
+ " tfidf_userknn_K30_cv-0 | \n",
+ " 0.005940 | \n",
+ " 0.034233 | \n",
+ " 0.006479 | \n",
+ " 8.262367 | \n",
+ " 0.000059 | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " 2 | \n",
+ " bm25_userknn_K30_cv-0 | \n",
+ " 0.002720 | \n",
+ " 0.013422 | \n",
+ " 0.002530 | \n",
+ " 9.535631 | \n",
+ " 0.000091 | \n",
+ "
\n",
+ " \n",
+ " | 15 | \n",
+ " 2 | \n",
+ " cosine_userknn_K40_cv-0 | \n",
+ " 0.003045 | \n",
+ " 0.017086 | \n",
+ " 0.003100 | \n",
+ " 8.661585 | \n",
+ " 0.000050 | \n",
+ "
\n",
+ " \n",
+ " | 16 | \n",
+ " 2 | \n",
+ " tfidf_userknn_K40_cv-0 | \n",
+ " 0.005914 | \n",
+ " 0.034071 | \n",
+ " 0.006439 | \n",
+ " 8.396618 | \n",
+ " 0.000063 | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " 2 | \n",
+ " bm25_userknn_K40_cv-0 | \n",
+ " 0.002404 | \n",
+ " 0.011638 | \n",
+ " 0.002231 | \n",
+ " 9.799119 | \n",
+ " 0.000090 | \n",
+ "
\n",
+ " \n",
+ " | 18 | \n",
+ " 3 | \n",
+ " cosine_userknn_K30_cv-0 | \n",
+ " 0.003277 | \n",
+ " 0.018786 | \n",
+ " 0.003395 | \n",
+ " 8.444986 | \n",
+ " 0.000045 | \n",
+ "
\n",
+ " \n",
+ " | 19 | \n",
+ " 3 | \n",
+ " tfidf_userknn_K30_cv-0 | \n",
+ " 0.006023 | \n",
+ " 0.034171 | \n",
+ " 0.006328 | \n",
+ " 8.276503 | \n",
+ " 0.000059 | \n",
+ "
\n",
+ " \n",
+ " | 20 | \n",
+ " 3 | \n",
+ " bm25_userknn_K30_cv-0 | \n",
+ " 0.002620 | \n",
+ " 0.012762 | \n",
+ " 0.002497 | \n",
+ " 9.560984 | \n",
+ " 0.000091 | \n",
+ "
\n",
+ " \n",
+ " | 21 | \n",
+ " 3 | \n",
+ " cosine_userknn_K40_cv-0 | \n",
+ " 0.003076 | \n",
+ " 0.017512 | \n",
+ " 0.003173 | \n",
+ " 8.658150 | \n",
+ " 0.000045 | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " 3 | \n",
+ " tfidf_userknn_K40_cv-0 | \n",
+ " 0.005919 | \n",
+ " 0.033368 | \n",
+ " 0.006253 | \n",
+ " 8.399169 | \n",
+ " 0.000062 | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " 3 | \n",
+ " bm25_userknn_K40_cv-0 | \n",
+ " 0.002337 | \n",
+ " 0.011273 | \n",
+ " 0.002253 | \n",
+ " 9.816325 | \n",
+ " 0.000089 | \n",
+ "
\n",
+ " \n",
+ " | 24 | \n",
+ " 4 | \n",
+ " cosine_userknn_K30_cv-0 | \n",
+ " 0.003118 | \n",
+ " 0.018064 | \n",
+ " 0.003157 | \n",
+ " 8.485899 | \n",
+ " 0.000042 | \n",
+ "
\n",
+ " \n",
+ " | 25 | \n",
+ " 4 | \n",
+ " tfidf_userknn_K30_cv-0 | \n",
+ " 0.005911 | \n",
+ " 0.033626 | \n",
+ " 0.006396 | \n",
+ " 8.282428 | \n",
+ " 0.000059 | \n",
+ "
\n",
+ " \n",
+ " | 26 | \n",
+ " 4 | \n",
+ " bm25_userknn_K30_cv-0 | \n",
+ " 0.002537 | \n",
+ " 0.012368 | \n",
+ " 0.002470 | \n",
+ " 9.599645 | \n",
+ " 0.000086 | \n",
+ "
\n",
+ " \n",
+ " | 27 | \n",
+ " 4 | \n",
+ " cosine_userknn_K40_cv-0 | \n",
+ " 0.002872 | \n",
+ " 0.016509 | \n",
+ " 0.002883 | \n",
+ " 8.711984 | \n",
+ " 0.000043 | \n",
+ "
\n",
+ " \n",
+ " | 28 | \n",
+ " 4 | \n",
+ " tfidf_userknn_K40_cv-0 | \n",
+ " 0.005793 | \n",
+ " 0.033028 | \n",
+ " 0.006261 | \n",
+ " 8.416680 | \n",
+ " 0.000062 | \n",
+ "
\n",
+ " \n",
+ " | 29 | \n",
+ " 4 | \n",
+ " bm25_userknn_K40_cv-0 | \n",
+ " 0.002213 | \n",
+ " 0.010860 | \n",
+ " 0.002179 | \n",
+ " 9.866201 | \n",
+ " 0.000085 | \n",
+ "
\n",
+ " \n",
+ " | 30 | \n",
+ " 5 | \n",
+ " cosine_userknn_K30_cv-0 | \n",
+ " 0.003003 | \n",
+ " 0.016252 | \n",
+ " 0.002899 | \n",
+ " 8.498968 | \n",
+ " 0.000043 | \n",
+ "
\n",
+ " \n",
+ " | 31 | \n",
+ " 5 | \n",
+ " tfidf_userknn_K30_cv-0 | \n",
+ " 0.005527 | \n",
+ " 0.030942 | \n",
+ " 0.005823 | \n",
+ " 8.325273 | \n",
+ " 0.000057 | \n",
+ "
\n",
+ " \n",
+ " | 32 | \n",
+ " 5 | \n",
+ " bm25_userknn_K30_cv-0 | \n",
+ " 0.002597 | \n",
+ " 0.012263 | \n",
+ " 0.002386 | \n",
+ " 9.646957 | \n",
+ " 0.000100 | \n",
+ "
\n",
+ " \n",
+ " | 33 | \n",
+ " 5 | \n",
+ " cosine_userknn_K40_cv-0 | \n",
+ " 0.002765 | \n",
+ " 0.014713 | \n",
+ " 0.002661 | \n",
+ " 8.717559 | \n",
+ " 0.000047 | \n",
+ "
\n",
+ " \n",
+ " | 34 | \n",
+ " 5 | \n",
+ " tfidf_userknn_K40_cv-0 | \n",
+ " 0.005545 | \n",
+ " 0.030892 | \n",
+ " 0.005817 | \n",
+ " 8.454091 | \n",
+ " 0.000059 | \n",
+ "
\n",
+ " \n",
+ " | 35 | \n",
+ " 5 | \n",
+ " bm25_userknn_K40_cv-0 | \n",
+ " 0.002302 | \n",
+ " 0.010777 | \n",
+ " 0.002135 | \n",
+ " 9.914042 | \n",
+ " 0.000100 | \n",
+ "
\n",
+ " \n",
+ " | 36 | \n",
+ " 6 | \n",
+ " cosine_userknn_K30_cv-0 | \n",
+ " 0.002963 | \n",
+ " 0.016532 | \n",
+ " 0.002887 | \n",
+ " 8.563809 | \n",
+ " 0.000050 | \n",
+ "
\n",
+ " \n",
+ " | 37 | \n",
+ " 6 | \n",
+ " tfidf_userknn_K30_cv-0 | \n",
+ " 0.005330 | \n",
+ " 0.030717 | \n",
+ " 0.005763 | \n",
+ " 8.366259 | \n",
+ " 0.000064 | \n",
+ "
\n",
+ " \n",
+ " | 38 | \n",
+ " 6 | \n",
+ " bm25_userknn_K30_cv-0 | \n",
+ " 0.002571 | \n",
+ " 0.012691 | \n",
+ " 0.002478 | \n",
+ " 9.715097 | \n",
+ " 0.000100 | \n",
+ "
\n",
+ " \n",
+ " | 39 | \n",
+ " 6 | \n",
+ " cosine_userknn_K40_cv-0 | \n",
+ " 0.002769 | \n",
+ " 0.015448 | \n",
+ " 0.002675 | \n",
+ " 8.775058 | \n",
+ " 0.000051 | \n",
+ "
\n",
+ " \n",
+ " | 40 | \n",
+ " 6 | \n",
+ " tfidf_userknn_K40_cv-0 | \n",
+ " 0.005284 | \n",
+ " 0.030418 | \n",
+ " 0.005697 | \n",
+ " 8.488473 | \n",
+ " 0.000066 | \n",
+ "
\n",
+ " \n",
+ " | 41 | \n",
+ " 6 | \n",
+ " bm25_userknn_K40_cv-0 | \n",
+ " 0.002340 | \n",
+ " 0.011278 | \n",
+ " 0.002208 | \n",
+ " 9.964664 | \n",
+ " 0.000099 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " fold model prec@10 recall@10 MAP@10 novelty \\\n",
+ "0 0 cosine_userknn_K30_cv-0 0.003557 0.021128 0.003695 8.331491 \n",
+ "1 0 tfidf_userknn_K30_cv-0 0.006439 0.039102 0.007335 8.155051 \n",
+ "2 0 bm25_userknn_K30_cv-0 0.002593 0.013494 0.002531 9.398467 \n",
+ "3 0 cosine_userknn_K40_cv-0 0.003282 0.019323 0.003401 8.561523 \n",
+ "4 0 tfidf_userknn_K40_cv-0 0.006178 0.037458 0.006957 8.300404 \n",
+ "5 0 bm25_userknn_K40_cv-0 0.002241 0.011255 0.002210 9.675533 \n",
+ "6 1 cosine_userknn_K30_cv-0 0.003505 0.020002 0.003580 8.398248 \n",
+ "7 1 tfidf_userknn_K30_cv-0 0.006328 0.036844 0.007022 8.240133 \n",
+ "8 1 bm25_userknn_K30_cv-0 0.002722 0.013856 0.002658 9.484692 \n",
+ "9 1 cosine_userknn_K40_cv-0 0.003245 0.018368 0.003305 8.626906 \n",
+ "10 1 tfidf_userknn_K40_cv-0 0.006150 0.035964 0.006916 8.377988 \n",
+ "11 1 bm25_userknn_K40_cv-0 0.002406 0.012067 0.002393 9.756458 \n",
+ "12 2 cosine_userknn_K30_cv-0 0.003261 0.018498 0.003295 8.439263 \n",
+ "13 2 tfidf_userknn_K30_cv-0 0.005940 0.034233 0.006479 8.262367 \n",
+ "14 2 bm25_userknn_K30_cv-0 0.002720 0.013422 0.002530 9.535631 \n",
+ "15 2 cosine_userknn_K40_cv-0 0.003045 0.017086 0.003100 8.661585 \n",
+ "16 2 tfidf_userknn_K40_cv-0 0.005914 0.034071 0.006439 8.396618 \n",
+ "17 2 bm25_userknn_K40_cv-0 0.002404 0.011638 0.002231 9.799119 \n",
+ "18 3 cosine_userknn_K30_cv-0 0.003277 0.018786 0.003395 8.444986 \n",
+ "19 3 tfidf_userknn_K30_cv-0 0.006023 0.034171 0.006328 8.276503 \n",
+ "20 3 bm25_userknn_K30_cv-0 0.002620 0.012762 0.002497 9.560984 \n",
+ "21 3 cosine_userknn_K40_cv-0 0.003076 0.017512 0.003173 8.658150 \n",
+ "22 3 tfidf_userknn_K40_cv-0 0.005919 0.033368 0.006253 8.399169 \n",
+ "23 3 bm25_userknn_K40_cv-0 0.002337 0.011273 0.002253 9.816325 \n",
+ "24 4 cosine_userknn_K30_cv-0 0.003118 0.018064 0.003157 8.485899 \n",
+ "25 4 tfidf_userknn_K30_cv-0 0.005911 0.033626 0.006396 8.282428 \n",
+ "26 4 bm25_userknn_K30_cv-0 0.002537 0.012368 0.002470 9.599645 \n",
+ "27 4 cosine_userknn_K40_cv-0 0.002872 0.016509 0.002883 8.711984 \n",
+ "28 4 tfidf_userknn_K40_cv-0 0.005793 0.033028 0.006261 8.416680 \n",
+ "29 4 bm25_userknn_K40_cv-0 0.002213 0.010860 0.002179 9.866201 \n",
+ "30 5 cosine_userknn_K30_cv-0 0.003003 0.016252 0.002899 8.498968 \n",
+ "31 5 tfidf_userknn_K30_cv-0 0.005527 0.030942 0.005823 8.325273 \n",
+ "32 5 bm25_userknn_K30_cv-0 0.002597 0.012263 0.002386 9.646957 \n",
+ "33 5 cosine_userknn_K40_cv-0 0.002765 0.014713 0.002661 8.717559 \n",
+ "34 5 tfidf_userknn_K40_cv-0 0.005545 0.030892 0.005817 8.454091 \n",
+ "35 5 bm25_userknn_K40_cv-0 0.002302 0.010777 0.002135 9.914042 \n",
+ "36 6 cosine_userknn_K30_cv-0 0.002963 0.016532 0.002887 8.563809 \n",
+ "37 6 tfidf_userknn_K30_cv-0 0.005330 0.030717 0.005763 8.366259 \n",
+ "38 6 bm25_userknn_K30_cv-0 0.002571 0.012691 0.002478 9.715097 \n",
+ "39 6 cosine_userknn_K40_cv-0 0.002769 0.015448 0.002675 8.775058 \n",
+ "40 6 tfidf_userknn_K40_cv-0 0.005284 0.030418 0.005697 8.488473 \n",
+ "41 6 bm25_userknn_K40_cv-0 0.002340 0.011278 0.002208 9.964664 \n",
+ "\n",
+ " serendipity \n",
+ "0 0.000040 \n",
+ "1 0.000048 \n",
+ "2 0.000081 \n",
+ "3 0.000043 \n",
+ "4 0.000052 \n",
+ "5 0.000081 \n",
+ "6 0.000046 \n",
+ "7 0.000058 \n",
+ "8 0.000088 \n",
+ "9 0.000047 \n",
+ "10 0.000061 \n",
+ "11 0.000086 \n",
+ "12 0.000047 \n",
+ "13 0.000059 \n",
+ "14 0.000091 \n",
+ "15 0.000050 \n",
+ "16 0.000063 \n",
+ "17 0.000090 \n",
+ "18 0.000045 \n",
+ "19 0.000059 \n",
+ "20 0.000091 \n",
+ "21 0.000045 \n",
+ "22 0.000062 \n",
+ "23 0.000089 \n",
+ "24 0.000042 \n",
+ "25 0.000059 \n",
+ "26 0.000086 \n",
+ "27 0.000043 \n",
+ "28 0.000062 \n",
+ "29 0.000085 \n",
+ "30 0.000043 \n",
+ "31 0.000057 \n",
+ "32 0.000100 \n",
+ "33 0.000047 \n",
+ "34 0.000059 \n",
+ "35 0.000100 \n",
+ "36 0.000050 \n",
+ "37 0.000064 \n",
+ "38 0.000100 \n",
+ "39 0.000051 \n",
+ "40 0.000066 \n",
+ "41 0.000099 "
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_metrics = pd.DataFrame(results)\n",
+ "df_metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "id": "a0334b9a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_metrics.to_pickle(\"../data/hw_3/df_metrics.pickle\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "446530ce",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " fold | \n",
+ " prec@10 | \n",
+ " recall@10 | \n",
+ " MAP@10 | \n",
+ " novelty | \n",
+ " serendipity | \n",
+ "
\n",
+ " \n",
+ " | model | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | bm25_userknn_K30_cv-0 | \n",
+ " 3.0 | \n",
+ " 0.002623 | \n",
+ " 0.012980 | \n",
+ " 0.002507 | \n",
+ " 9.563068 | \n",
+ " 0.000091 | \n",
+ "
\n",
+ " \n",
+ " | bm25_userknn_K40_cv-0 | \n",
+ " 3.0 | \n",
+ " 0.002320 | \n",
+ " 0.011307 | \n",
+ " 0.002230 | \n",
+ " 9.827477 | \n",
+ " 0.000090 | \n",
+ "
\n",
+ " \n",
+ " | cosine_userknn_K30_cv-0 | \n",
+ " 3.0 | \n",
+ " 0.003241 | \n",
+ " 0.018466 | \n",
+ " 0.003272 | \n",
+ " 8.451809 | \n",
+ " 0.000045 | \n",
+ "
\n",
+ " \n",
+ " | cosine_userknn_K40_cv-0 | \n",
+ " 3.0 | \n",
+ " 0.003008 | \n",
+ " 0.016994 | \n",
+ " 0.003028 | \n",
+ " 8.673252 | \n",
+ " 0.000047 | \n",
+ "
\n",
+ " \n",
+ " | tfidf_userknn_K30_cv-0 | \n",
+ " 3.0 | \n",
+ " 0.005928 | \n",
+ " 0.034234 | \n",
+ " 0.006449 | \n",
+ " 8.272573 | \n",
+ " 0.000058 | \n",
+ "
\n",
+ " \n",
+ " | tfidf_userknn_K40_cv-0 | \n",
+ " 3.0 | \n",
+ " 0.005826 | \n",
+ " 0.033600 | \n",
+ " 0.006334 | \n",
+ " 8.404775 | \n",
+ " 0.000061 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " fold prec@10 recall@10 MAP@10 novelty \\\n",
+ "model \n",
+ "bm25_userknn_K30_cv-0 3.0 0.002623 0.012980 0.002507 9.563068 \n",
+ "bm25_userknn_K40_cv-0 3.0 0.002320 0.011307 0.002230 9.827477 \n",
+ "cosine_userknn_K30_cv-0 3.0 0.003241 0.018466 0.003272 8.451809 \n",
+ "cosine_userknn_K40_cv-0 3.0 0.003008 0.016994 0.003028 8.673252 \n",
+ "tfidf_userknn_K30_cv-0 3.0 0.005928 0.034234 0.006449 8.272573 \n",
+ "tfidf_userknn_K40_cv-0 3.0 0.005826 0.033600 0.006334 8.404775 \n",
+ "\n",
+ " serendipity \n",
+ "model \n",
+ "bm25_userknn_K30_cv-0 0.000091 \n",
+ "bm25_userknn_K40_cv-0 0.000090 \n",
+ "cosine_userknn_K30_cv-0 0.000045 \n",
+ "cosine_userknn_K40_cv-0 0.000047 \n",
+ "tfidf_userknn_K30_cv-0 0.000058 \n",
+ "tfidf_userknn_K40_cv-0 0.000061 "
+ ]
+ },
+ "execution_count": 48,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_metrics.groupby('model').mean()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "5fb9ba9f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " prec@10 | \n",
+ " recall@10 | \n",
+ " MAP@10 | \n",
+ " novelty | \n",
+ " serendipity | \n",
+ "
\n",
+ " \n",
+ " | model | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | bm25_userknn_K30_cv-0 | \n",
+ " 0.000072 | \n",
+ " 0.000612 | \n",
+ " 0.000083 | \n",
+ " 0.104468 | \n",
+ " 0.000007 | \n",
+ "
\n",
+ " \n",
+ " | bm25_userknn_K40_cv-0 | \n",
+ " 0.000074 | \n",
+ " 0.000442 | \n",
+ " 0.000081 | \n",
+ " 0.097359 | \n",
+ " 0.000007 | \n",
+ "
\n",
+ " \n",
+ " | cosine_userknn_K30_cv-0 | \n",
+ " 0.000231 | \n",
+ " 0.001749 | \n",
+ " 0.000314 | \n",
+ " 0.074699 | \n",
+ " 0.000003 | \n",
+ "
\n",
+ " \n",
+ " | cosine_userknn_K40_cv-0 | \n",
+ " 0.000213 | \n",
+ " 0.001603 | \n",
+ " 0.000295 | \n",
+ " 0.069310 | \n",
+ " 0.000003 | \n",
+ "
\n",
+ " \n",
+ " | tfidf_userknn_K30_cv-0 | \n",
+ " 0.000398 | \n",
+ " 0.003003 | \n",
+ " 0.000577 | \n",
+ " 0.066627 | \n",
+ " 0.000005 | \n",
+ "
\n",
+ " \n",
+ " | tfidf_userknn_K40_cv-0 | \n",
+ " 0.000321 | \n",
+ " 0.002534 | \n",
+ " 0.000487 | \n",
+ " 0.059565 | \n",
+ " 0.000004 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " prec@10 recall@10 MAP@10 novelty serendipity\n",
+ "model \n",
+ "bm25_userknn_K30_cv-0 0.000072 0.000612 0.000083 0.104468 0.000007\n",
+ "bm25_userknn_K40_cv-0 0.000074 0.000442 0.000081 0.097359 0.000007\n",
+ "cosine_userknn_K30_cv-0 0.000231 0.001749 0.000314 0.074699 0.000003\n",
+ "cosine_userknn_K40_cv-0 0.000213 0.001603 0.000295 0.069310 0.000003\n",
+ "tfidf_userknn_K30_cv-0 0.000398 0.003003 0.000577 0.066627 0.000005\n",
+ "tfidf_userknn_K40_cv-0 0.000321 0.002534 0.000487 0.059565 0.000004"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_metrics.groupby('model').std()[metrics.keys()]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "41828ee5",
+ "metadata": {},
+ "source": [
+ "По **offline**-метрикам лучше всего себя показывает модель TFIDFRecommender.\n",
+ "Далее подбираем K для TFIDFRecommender."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7a8a0a41",
+ "metadata": {},
+ "source": [
+ "# Подбор оптимального K для TFIDFRecommender"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "1e91892d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'tfidf_userknn_K50': ,\n",
+ " 'tfidf_userknn_K60': ,\n",
+ " 'tfidf_userknn_K70': }"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "N_USERS = 50\n",
+ "\n",
+ "# Metrics for K = 30 and K = 40 are already available, so the search starts at 50\n",
+ "K = list(range(50, 71, 10))\n",
+ "models = {}\n",
+ "\n",
+ "for k in K:\n",
+ "    models[f\"tfidf_userknn_K{k}\"] = TFIDFRecommender(K=k)\n",
+ "models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7c2c43b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "==================== Fold 0\n",
+ "{'End date': Timestamp('2021-07-18 00:00:00', freq='5D'),\n",
+ " 'Start date': Timestamp('2021-07-13 00:00:00', freq='5D'),\n",
+ " 'Test': 156580,\n",
+ " 'Test items': 5793,\n",
+ " 'Test users': 68150,\n",
+ " 'Train': 3281612,\n",
+ " 'Train items': 14754,\n",
+ " 'Train users': 652905}\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "211234f034a54bae86b94dff33b9f5c4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/652905 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "results_idf = []\n",
+ "\n",
+ "fold_iterator = cv_v1.split(Interactions(interactions), collect_fold_stats=True)\n",
+ "\n",
+ "for i_fold, (train_ids, test_ids, fold_info) in enumerate(fold_iterator):\n",
+ "    print(f\"\\n==================== Fold {i_fold}\")\n",
+ "    pprint(fold_info)\n",
+ "\n",
+ "    df_train = interactions.iloc[train_ids].copy()\n",
+ "    df_test = interactions.iloc[test_ids][Columns.UserItem].copy()\n",
+ "\n",
+ "    catalog = df_train[Columns.Item].unique()\n",
+ "\n",
+ "    for model_name, model in models.items():\n",
+ "        # Same UserKnn settings as in the K=30/K=40 comparison run and in the\n",
+ "        # final model, so that the metrics for different K stay comparable.\n",
+ "        userknn_model = UserKnn(model=model, N_users=N_USERS, use_weight_idf=True)\n",
+ "        userknn_model.fit(df_train)\n",
+ "        recos = userknn_model.predict(df_test)\n",
+ "\n",
+ "        metric_values = calc_metrics(\n",
+ "            metrics,\n",
+ "            reco=recos,\n",
+ "            interactions=df_test,\n",
+ "            prev_interactions=df_train,\n",
+ "            catalog=catalog,\n",
+ "        )\n",
+ "\n",
+ "        # one row per (fold, model): identifiers first, then the metric values\n",
+ "        fold = {\"fold\": i_fold, \"model\": model_name}\n",
+ "        fold.update(metric_values)\n",
+ "        results_idf.append(fold)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "37a896aa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_metrics_tfidf = pd.DataFrame(results_idf)\n",
+ "df_metrics_tfidf"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "073517d2",
+ "metadata": {},
+ "source": [
+ "# Train TFIDFRecommender on all data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "13ca7a8b",
+ "metadata": {},
+ "source": [
+ "Обучение TFIDFRecommender на всём объеме данных"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "278879e0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cc8fbbecd80145dea10421894ede12e0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/962179 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 2h 27min 37s, sys: 25.2 s, total: 2h 28min 2s\n",
+ "Wall time: 26min 59s\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "\n",
+ "# Final model: fit on the complete interactions log (no hold-out part).\n",
+ "# `results` is deliberately not reset here, so the cross-validation results\n",
+ "# collected earlier in the notebook stay available.\n",
+ "df_train = interactions.copy()\n",
+ "catalog = df_train[Columns.Item].unique()\n",
+ "\n",
+ "tfidf_model = TFIDFRecommender(K=30)\n",
+ "userknn_model = UserKnn(model=tfidf_model, N_users=50, use_weight_idf=True)\n",
+ "userknn_model.fit(df_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "49f8bf89",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import dill\n",
+ "\n",
+ "with open('../service/weights/userKNN/userknn_tfidf_k30.dill', 'wb') as f:\n",
+ " dill.dump(userknn_model.user_knn, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "fd4b5830",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "user_id 176549\n",
+ "item_id 9506\n",
+ "datetime 2021-05-11 00:00:00\n",
+ "weight 4250\n",
+ "watched_pct 72.0\n",
+ "Name: 0, dtype: object"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_train.iloc[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "9408bc58",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "False"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "555555555 in df_train[Columns.User].tolist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "f614ccaf",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " last_watch_dt | \n",
+ " total_dur | \n",
+ " watched_pct | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 176549 | \n",
+ " 9506 | \n",
+ " 2021-05-11 | \n",
+ " 4250 | \n",
+ " 72.0 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 699317 | \n",
+ " 1659 | \n",
+ " 2021-05-29 | \n",
+ " 8317 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 656683 | \n",
+ " 7107 | \n",
+ " 2021-05-09 | \n",
+ " 10 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 864613 | \n",
+ " 7638 | \n",
+ " 2021-07-05 | \n",
+ " 14483 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 964868 | \n",
+ " 9506 | \n",
+ " 2021-04-30 | \n",
+ " 6725 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476246 | \n",
+ " 648596 | \n",
+ " 12225 | \n",
+ " 2021-08-13 | \n",
+ " 76 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476247 | \n",
+ " 546862 | \n",
+ " 9673 | \n",
+ " 2021-04-13 | \n",
+ " 2308 | \n",
+ " 49.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476248 | \n",
+ " 697262 | \n",
+ " 15297 | \n",
+ " 2021-08-20 | \n",
+ " 18307 | \n",
+ " 63.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476249 | \n",
+ " 384202 | \n",
+ " 16197 | \n",
+ " 2021-04-19 | \n",
+ " 6203 | \n",
+ " 100.0 | \n",
+ "
\n",
+ " \n",
+ " | 5476250 | \n",
+ " 319709 | \n",
+ " 4436 | \n",
+ " 2021-08-15 | \n",
+ " 3921 | \n",
+ " 45.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id last_watch_dt total_dur watched_pct\n",
+ "0 176549 9506 2021-05-11 4250 72.0\n",
+ "1 699317 1659 2021-05-29 8317 100.0\n",
+ "2 656683 7107 2021-05-09 10 0.0\n",
+ "3 864613 7638 2021-07-05 14483 100.0\n",
+ "4 964868 9506 2021-04-30 6725 100.0\n",
+ "5476246 648596 12225 2021-08-13 76 0.0\n",
+ "5476247 546862 9673 2021-04-13 2308 49.0\n",
+ "5476248 697262 15297 2021-08-20 18307 63.0\n",
+ "5476249 384202 16197 2021-04-19 6203 100.0\n",
+ "5476250 319709 4436 2021-08-15 3921 45.0"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.concat([interactions.head(), interactions.tail()])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "dc4d9fd7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(962179,)"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions['user_id'].unique().shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "b7861d19",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[(961833, 1.0),\n",
+ " (961849, 1.0),\n",
+ " (961857, 1.0),\n",
+ " (961871, 1.0),\n",
+ " (961873, 1.0),\n",
+ " (961876, 1.0),\n",
+ " (961887, 1.0),\n",
+ " (961907, 1.0),\n",
+ " (961910, 1.0),\n",
+ " (961912, 1.0)]"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import dill\n",
+ "\n",
+ "with open('../service/weights/userKNN/userknn_tfidf_k30.dill', 'rb') as f:\n",
+ " userknn = dill.load(f)\n",
+ "\n",
+ "userknn.similar_items(962178, 10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1905033a",
+ "metadata": {},
+ "source": [
+ "# Popular Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "2df74dba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from rectools.models import PopularModel\n",
+ "from rectools.dataset import Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "6ba37a73",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Timestamp('2021-08-22 00:00:00')"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "max_date = interactions[Columns.Datetime].max().normalize()\n",
+ "max_date"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "901353f9",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Hold out the last 5 days (counted back from max_date) as the test part.\n",
+ "split_date = max_date - pd.Timedelta(5, \"D\")\n",
+ "split_columns = [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]\n",
+ "\n",
+ "# Both masks are built explicitly so that rows with a missing datetime\n",
+ "# end up in neither part, exactly as with the two separate comparisons.\n",
+ "is_train = interactions[Columns.Datetime] < split_date\n",
+ "is_test = interactions[Columns.Datetime] >= split_date\n",
+ "\n",
+ "# Single .loc selection instead of chained indexing df[cols][mask]\n",
+ "train = interactions.loc[is_train, split_columns]\n",
+ "test = interactions.loc[is_test, split_columns]\n",
+ "\n",
+ "dataset_train = Dataset.construct(train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 144,
+ "id": "f08e3579",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "popilarity_models = {\n",
+ " \"popular\": PopularModel(),\n",
+ " \"popular_mw\": PopularModel(popularity=\"mean_weight\")\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 145,
+ "id": "03c3bfb6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "popilarity_models[\"popular\"].fit(dataset_train)\n",
+ "popilarity_models[\"popular_mw\"].fit(dataset_train);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 146,
+ "id": "0d7de49e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ 24, 20, 31, 15, 167, 81, 89, 135, 355, 116])"
+ ]
+ },
+ "execution_count": 146,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "popilarity_models[\"popular\"].popularity_list[0][:10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 147,
+ "id": "05ff208d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([11363, 11681, 12841, 13017, 2069, 13691, 13552, 13397, 11774,\n",
+ " 12913])"
+ ]
+ },
+ "execution_count": 147,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "popilarity_models[\"popular_mw\"].popularity_list[0][:10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 148,
+ "id": "00ef735c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pecos_pop = popilarity_models[\"popular\"].recommend(\n",
+ " users=test[Columns.User].unique(),\n",
+ " dataset=dataset,\n",
+ " k=100,\n",
+ " filter_viewed=False,\n",
+ ")\n",
+ "\n",
+ "pecos_pop_mw = popilarity_models[\"popular_mw\"].recommend(\n",
+ " users=test[Columns.User].unique(),\n",
+ " dataset=dataset,\n",
+ " k=100,\n",
+ " filter_viewed=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 152,
+ "id": "b302db55",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "metrics = {\n",
+ " \"prec@5\": Precision(k=5),\n",
+ " \"recall@5\": Recall(k=5),\n",
+ " \"MAP@5\": MAP(k=5),\n",
+ " \"prec@10\": Precision(k=10),\n",
+ " \"recall@10\": Recall(k=10),\n",
+ " \"MAP@20\": MAP(k=20),\n",
+ " \"prec@20\": Precision(k=20),\n",
+ " \"recall@20\": Recall(k=20),\n",
+ " \"MAP@100\": MAP(k=100),\n",
+ " \"prec@100\": Precision(k=100),\n",
+ " \"recall@100\": Recall(k=100),\n",
+ " \"MAP@100\": MAP(k=100),\n",
+ " \"novelty\": MeanInvUserFreq(k=10),\n",
+ " \"serendipity\": Serendipity(k=10),\n",
+ "}\n",
+ "catalog = train[Columns.Item].unique()\n",
+ "metric_values_pop = calc_metrics(metrics, pecos_pop, test, train, catalog)\n",
+ "metric_values_pop_mean_weight = calc_metrics(metrics, pecos_pop_mw, test, train, catalog)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 153,
+ "id": "9631093b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'prec@5': 0.0017855613317256697,\n",
+ " 'recall@5': 0.004623809755660008,\n",
+ " 'prec@10': 0.0011648975773029461,\n",
+ " 'recall@10': 0.005682095875283048,\n",
+ " 'prec@20': 0.0010502526799891945,\n",
+ " 'recall@20': 0.00880186008464912,\n",
+ " 'prec@100': 0.003247020220987923,\n",
+ " 'recall@100': 0.16609031082955295,\n",
+ " 'MAP@5': 0.0013179725619140792,\n",
+ " 'MAP@20': 0.0016695313583723814,\n",
+ " 'MAP@100': 0.005578924867474493,\n",
+ " 'novelty': 9.976033936531364,\n",
+ " 'serendipity': 1.2752762676592953e-05}"
+ ]
+ },
+ "execution_count": 153,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "metric_values_pop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 154,
+ "id": "5d55b781",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'prec@5': 9.09252633867684e-05,\n",
+ " 'recall@5': 0.00014799438063171262,\n",
+ " 'prec@10': 4.612151041357817e-05,\n",
+ " 'recall@10': 0.00015458316783365238,\n",
+ " 'prec@20': 2.635514880775895e-05,\n",
+ " 'recall@20': 0.00016946607539568094,\n",
+ " 'prec@100': 0.00015147621777259455,\n",
+ " 'recall@100': 0.0065476971391510656,\n",
+ " 'MAP@5': 3.0257754846536496e-05,\n",
+ " 'MAP@20': 3.1771198360212185e-05,\n",
+ " 'MAP@100': 0.00011355765992119742,\n",
+ " 'novelty': 17.423655787689828,\n",
+ " 'serendipity': 1.8991632826477633e-06}"
+ ]
+ },
+ "execution_count": 154,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "metric_values_pop_mean_weight"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e5a4a011",
+ "metadata": {},
+ "source": [
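+ "**По офлайн-метрикам качества (precision, recall, MAP, serendipity) выигрывает обычная PopularModel с параметрами по умолчанию; у варианта с popularity=\"mean_weight\" выше только novelty**"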
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5875fab7",
+ "metadata": {},
+ "source": [
+ "# Save item_idf data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6589996f",
+ "metadata": {},
+ "source": [
+ "Считаем idf-вес для каждого айтема по всем взаимодействиям и сохраняем получившуюся таблицу для использования в будущем"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "d62cabb9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " index | \n",
+ " idf | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 9506 | \n",
+ " 7.150811 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 1659 | \n",
+ " 8.524953 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 7107 | \n",
+ " 5.821207 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 7638 | \n",
+ " 8.407093 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 6686 | \n",
+ " 7.778734 | \n",
+ "
\n",
+ " \n",
+ " | ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " | 15701 | \n",
+ " 7833 | \n",
+ " 14.822785 | \n",
+ "
\n",
+ " \n",
+ " | 15702 | \n",
+ " 9125 | \n",
+ " 14.822785 | \n",
+ "
\n",
+ " \n",
+ " | 15703 | \n",
+ " 10064 | \n",
+ " 14.822785 | \n",
+ "
\n",
+ " \n",
+ " | 15704 | \n",
+ " 13019 | \n",
+ " 14.822785 | \n",
+ "
\n",
+ " \n",
+ " | 15705 | \n",
+ " 10542 | \n",
+ " 14.822785 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
15706 rows × 2 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " index idf\n",
+ "0 9506 7.150811\n",
+ "1 1659 8.524953\n",
+ "2 7107 5.821207\n",
+ "3 7638 8.407093\n",
+ "4 6686 7.778734\n",
+ "... ... ...\n",
+ "15701 7833 14.822785\n",
+ "15702 9125 14.822785\n",
+ "15703 10064 14.822785\n",
+ "15704 13019 14.822785\n",
+ "15705 10542 14.822785\n",
+ "\n",
+ "[15706 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "item_cnt = Counter(interactions['item_id'].values)\n",
+ "item_idf = pd.DataFrame.from_dict(item_cnt, orient='index', columns=['doc_freq']).reset_index()\n",
+ "n = interactions.shape[0]\n",
+ "item_idf['idf'] = item_idf['doc_freq'].apply(lambda x: np.log((1 + n) / (1 + x) + 1))\n",
+ "del item_idf['doc_freq']\n",
+ "item_idf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "7da47dfc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "item_idf = item_idf.sort_values(\"idf\", ascending=False)\n",
+ "item_idf.to_csv('../data/kion_train/items_idf.csv', index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fdce2b60",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/hw_2.ipynb b/hw_2.ipynb
new file mode 100644
index 00000000..b4d81fa6
--- /dev/null
+++ b/hw_2.ipynb
@@ -0,0 +1,440 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "from pprint import pprint\n",
+ "\n",
+ "import copy\n",
+ "\n",
+ "from tqdm.auto import tqdm\n",
+ "\n",
+ "from implicit.nearest_neighbours import TFIDFRecommender, BM25Recommender\n",
+ "from implicit.als import AlternatingLeastSquares\n",
+ "\n",
+ "\n",
+ "from rectools import Columns\n",
+ "from rectools.dataset import Interactions, Dataset\n",
+ "from rectools.metrics import Precision, Recall, MeanInvUserFreq, Serendipity, calc_metrics, MAP, MRR\n",
+ "from rectools.models import ImplicitItemKNNWrapperModel, RandomModel, PopularModel\n",
+ "from rectools.model_selection import TimeRangeSplitter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 123,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = pd.read_csv('data_original/interactions.csv', parse_dates=['last_watch_dt'])\n",
+ "\n",
+ "df.rename(\n",
+ " columns={\n",
+ " 'last_watch_dt': Columns.Datetime,\n",
+ " 'total_dur': Columns.Weight\n",
+ " }, \n",
+ " inplace=True) \n",
+ "\n",
+ "interactions = Interactions(df)\n",
+ "\n",
+ "\n",
+ "users = pd.read_csv('data_original/users.csv')\n",
+ "items = pd.read_csv('data_original/items.csv')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 124,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " age | \n",
+ " income | \n",
+ " sex | \n",
+ " kids_flg | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 373089 | \n",
+ " 666262 | \n",
+ " age_65_inf | \n",
+ " income_20_40 | \n",
+ " Ж | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id age income sex kids_flg\n",
+ "373089 666262 age_65_inf income_20_40 Ж 0"
+ ]
+ },
+ "execution_count": 124,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "users[users['user_id'] == 666262]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Функция для расчета метрик"
+ ]
+ },
+ {
+ "attachments": {
+ "image.png": {
+ "image/png": ""
+ }
+ },
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Модели: rectools.models.RandomModel(random_state=32), rectools.models.PopularModel() с параметрами по умолчанию\n",
+ "models = {\n",
+ " \"random\": RandomModel(random_state=32),\n",
+ " \"popular\": PopularModel()\n",
+ "}\n",
+ "\n",
+ "# Метрики: 2 ранжирующие, 2 классификационные, 2 beyond-accuracy. Считаем по порогам 1, 5, 10. MAP обязательно\n",
+ "metrics = {\n",
+ " # классификационные\n",
+ " \"prec@1\": Precision(k=1),\n",
+ " \"prec@10\": Precision(k=5),\n",
+ " \"prec@10\": Precision(k=10),\n",
+ " \"recall\": Recall(k=1),\n",
+ " \"recall\": Recall(k=5),\n",
+ " \"recall\": Recall(k=10),\n",
+ " # ранжирующие\n",
+ " \"MAP\": MAP(k=1),\n",
+ " \"MAP\": MAP(k=5),\n",
+ " \"MAP\": MAP(k=10),\n",
+ " # среднее значение обратного ранга\n",
+ " \"MRR\": MRR(k=1),\n",
+ " \"MRR\": MRR(k=5),\n",
+ " \"MRR\": MRR(k=10),\n",
+ " \"novelty\": MeanInvUserFreq(k=10),\n",
+ " \"serendipity\": Serendipity(k=10),\n",
+ "}\n",
+ "\n",
+ "# 3 фолда для кросс-валидации, тестовый период каждого фолда — 14 дней (2 недели)\n",
+ "n_splits = 3\n",
+ "test_size = \"14D\"\n",
+ "\n",
+ "# Инициализированный Splitter для кросс-валидации\n",
+ "cv = TimeRangeSplitter(\n",
+ " test_size= test_size,\n",
+ " n_splits=n_splits,\n",
+ " filter_already_seen=True,\n",
+ " filter_cold_items=True,\n",
+ " filter_cold_users=True,\n",
+ ")\n",
+ "\n",
+ "# Количество рекомендаций для генерации (K)\n",
+ "K_RECOS = 10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = Dataset.construct(\n",
+ " interactions_df=interactions,\n",
+ " user_features_df=None,\n",
+ " item_features_df=None,\n",
+ " )\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 103,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 5 µs, sys: 0 ns, total: 5 µs\n",
+ "Wall time: 32.2 µs\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "import time\n",
+ "\n",
+ "def evaluate_models(interactions, models, metrics, cv, K_RECOS):\n",
+ " results = []\n",
+ " trained_models = {}\n",
+ "\n",
+ " # n_splits = cv.get_n_splits()\n",
+ " fold_iterator = cv.split(interactions, collect_fold_stats=True)\n",
+ "\n",
+ " for train_ids, test_ids, fold_info in tqdm(fold_iterator, total=n_splits):\n",
+ " print(f\"\\n==================== Fold {fold_info['i_split']}\")\n",
+ " pprint(fold_info)\n",
+ "\n",
+ " df_train = interactions.df.iloc[train_ids]\n",
+ " # Создаем RecTools Dataset через метод construct на train взаимодействиях для каждого фолда\n",
+ " dataset = Dataset.construct(df_train)\n",
+ " # Определили test\n",
+ " df_test = interactions.df.iloc[test_ids] # Предполагается, что Columns.UserItem определено\n",
+ " test_users = np.unique(df_test[Columns.User])\n",
+ "\n",
+ " catalog = df_train[Columns.Item].unique() # Каталог для рекомендаций\n",
+ "\n",
+ " # Обучаем модель (не забываем сделать deepcopy), рекомендуем K айтемов для каждого юзера, считаем метрики на test\n",
+ " for model_name, model in models.items():\n",
+ " \n",
+ " model_copy = copy.deepcopy(model)\n",
+ " # время перед началом обучения\n",
+ " start_time = time.time()\n",
+ " model.fit(dataset)\n",
+ " recos = model.recommend(\n",
+ " users=test_users,\n",
+ " dataset=dataset,\n",
+ " k=K_RECOS,\n",
+ " filter_viewed=True,\n",
+ " )\n",
+ " metric_values = calc_metrics(\n",
+ " metrics,\n",
+ " reco=recos,\n",
+ " interactions=df_test,\n",
+ " prev_interactions=df_train,\n",
+ " catalog=catalog,\n",
+ " )\n",
+ " \n",
+ " # суммарное время обучения, генерации рекомендаций и расчета метрик\n",
+ " training_time = time.time() - start_time\n",
+ "\n",
+ " res = {\"fold\": fold_info[\"i_split\"], \"model\": model_name, \"training_time\": training_time}\n",
+ " res.update(metric_values)\n",
+ " results.append(res)\n",
+ "\n",
+ " # Сохраняем обученную модель\n",
+ " if fold_info['i_split'] == n_splits - 1: # Последний фолд\n",
+ " trained_models[model_name] = model_copy\n",
+ " \n",
+ "\n",
+ " # Результат оборачиваем в pandas DataFrame и усредняем по фолдам\n",
+ " results_df = pd.DataFrame(results)\n",
+ " average_results = results_df.groupby('model').mean()\n",
+ " average_results = average_results.reset_index()\n",
+ " return average_results, trained_models\n",
+ "\n",
+ "# %%time\n",
+ "# df_rec, trained_models = evaluate_models(interactions, models, metrics, cv, K_RECOS)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "==================== Fold 0\n",
+ "{'end': Timestamp('2021-07-26 00:00:00'),\n",
+ " 'i_split': 0,\n",
+ " 'start': Timestamp('2021-07-12 00:00:00'),\n",
+ " 'test': 398993,\n",
+ " 'test_items': 7394,\n",
+ " 'test_users': 122488,\n",
+ " 'train': 3239125,\n",
+ " 'train_items': 14730,\n",
+ " 'train_users': 646423}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 33%|███▎ | 1/3 [00:17<00:35, 17.62s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "==================== Fold 1\n",
+ "{'end': Timestamp('2021-08-09 00:00:00'),\n",
+ " 'i_split': 1,\n",
+ " 'start': Timestamp('2021-07-26 00:00:00'),\n",
+ " 'test': 458757,\n",
+ " 'test_items': 7711,\n",
+ " 'test_users': 135624,\n",
+ " 'train': 3892558,\n",
+ " 'train_items': 15085,\n",
+ " 'train_users': 742256}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 67%|██████▋ | 2/3 [00:36<00:18, 18.27s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "==================== Fold 2\n",
+ "{'end': Timestamp('2021-08-23 00:00:00'),\n",
+ " 'i_split': 2,\n",
+ " 'start': Timestamp('2021-08-09 00:00:00'),\n",
+ " 'test': 521381,\n",
+ " 'test_items': 7705,\n",
+ " 'test_users': 151629,\n",
+ " 'train': 4649162,\n",
+ " 'train_items': 15415,\n",
+ " 'train_users': 850489}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 3/3 [00:58<00:00, 19.50s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 54 s, sys: 3.04 s, total: 57.1 s\n",
+ "Wall time: 58.5 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "df_rec, trained_models = evaluate_models(interactions, models, metrics, cv, K_RECOS)"
+ ]
+ },
+ {
+ "attachments": {
+ "image.png": {
+ "image/png": ""
+ }
+ },
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 169,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def reccomend(model, dataset, list_user, data, items):\n",
+ "\n",
+ " history_all = []\n",
+ " recos_all = []\n",
+ "\n",
+ " for i in list_user:\n",
+ " recos = model.recommend(\n",
+ " users= df[Columns.User][df[Columns.User] == i].unique(),\n",
+ " dataset=dataset,\n",
+ " k=10,\n",
+ " filter_viewed=True,\n",
+ " )\n",
+ "\n",
+ " history_all.append(df[df['user_id'] == i].merge(users, how = 'left', on = 'user_id').merge(items, how = 'left', on = 'item_id'))\n",
+ " recos_all.append(recos.merge(users, how = 'left', on = 'user_id').merge(items, how = 'left', on = 'item_id'))\n",
+ "\n",
+ " return history_all, recos_all\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/hw_5_autoencoder.ipynb b/hw_5_autoencoder.ipynb
new file mode 100644
index 00000000..750a793d
--- /dev/null
+++ b/hw_5_autoencoder.ipynb
@@ -0,0 +1,1922 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "7_8DlX_2jZzT",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:32:12.524590Z",
+ "iopub.status.busy": "2023-01-22T12:32:12.523513Z",
+ "iopub.status.idle": "2023-01-22T12:32:12.529931Z",
+ "shell.execute_reply": "2023-01-22T12:32:12.528298Z",
+ "shell.execute_reply.started": "2023-01-22T12:32:12.524533Z"
+ },
+ "id": "7_8DlX_2jZzT"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import os\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "IczRXBXHjZzV",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:32:13.867299Z",
+ "iopub.status.busy": "2023-01-22T12:32:13.866000Z",
+ "iopub.status.idle": "2023-01-22T12:32:16.353124Z",
+ "shell.execute_reply": "2023-01-22T12:32:16.352004Z",
+ "shell.execute_reply.started": "2023-01-22T12:32:13.867251Z"
+ },
+ "id": "IczRXBXHjZzV"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import display, clear_output\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from tqdm.notebook import tqdm\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torch.nn import functional as F\n",
+ "from torch.utils.data import Dataset, DataLoader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "mA1MfXOnjZzW",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:13.399626Z",
+ "iopub.status.busy": "2023-01-22T12:41:13.398452Z",
+ "iopub.status.idle": "2023-01-22T12:41:19.723408Z",
+ "shell.execute_reply": "2023-01-22T12:41:19.722114Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:13.399496Z"
+ },
+ "id": "mA1MfXOnjZzW"
+ },
+ "outputs": [],
+ "source": [
+ "interactions_df = pd.read_csv('interactions_processed_kion.csv')\n",
+ "users_df = pd.read_csv('users_processed_kion.csv')\n",
+ "items_df = pd.read_csv('items_processed_kion.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "G5cP9QcUjZzW",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 204
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:19.726341Z",
+ "iopub.status.busy": "2023-01-22T12:41:19.725645Z",
+ "iopub.status.idle": "2023-01-22T12:41:19.751544Z",
+ "shell.execute_reply": "2023-01-22T12:41:19.750286Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:19.726296Z"
+ },
+ "id": "G5cP9QcUjZzW",
+ "outputId": "9fc311f0-6f5b-4327-9bbc-1f6bdee3f918"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " last_watch_dt | \n",
+ " total_dur | \n",
+ " watched_pct | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 176549 | \n",
+ " 9506 | \n",
+ " 2021-05-11 | \n",
+ " 4250 | \n",
+ " 72 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 699317 | \n",
+ " 1659 | \n",
+ " 2021-05-29 | \n",
+ " 8317 | \n",
+ " 100 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 656683 | \n",
+ " 7107 | \n",
+ " 2021-05-09 | \n",
+ " 10 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 864613 | \n",
+ " 7638 | \n",
+ " 2021-07-05 | \n",
+ " 14483 | \n",
+ " 100 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 964868 | \n",
+ " 9506 | \n",
+ " 2021-04-30 | \n",
+ " 6725 | \n",
+ " 100 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " user_id item_id last_watch_dt total_dur watched_pct\n",
+ "0 176549 9506 2021-05-11 4250 72\n",
+ "1 699317 1659 2021-05-29 8317 100\n",
+ "2 656683 7107 2021-05-09 10 0\n",
+ "3 864613 7638 2021-07-05 14483 100\n",
+ "4 964868 9506 2021-04-30 6725 100"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "b4omWvMOjZzX",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:21.721270Z",
+ "iopub.status.busy": "2023-01-22T12:41:21.720745Z",
+ "iopub.status.idle": "2023-01-22T12:41:22.116852Z",
+ "shell.execute_reply": "2023-01-22T12:41:22.115397Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:21.721229Z"
+ },
+ "id": "b4omWvMOjZzX"
+ },
+ "outputs": [],
+ "source": [
+ "interactions_df = interactions_df[interactions_df['last_watch_dt'] < '2021-04-01']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "JAuH-fG0jZzX",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:23.195240Z",
+ "iopub.status.busy": "2023-01-22T12:41:23.194661Z",
+ "iopub.status.idle": "2023-01-22T12:41:23.202760Z",
+ "shell.execute_reply": "2023-01-22T12:41:23.201745Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:23.195188Z"
+ },
+ "id": "JAuH-fG0jZzX",
+ "outputId": "3e259108-40e9-48fc-c212-1bbd7e9e21a1"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(263874, 5)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions_df.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "rWCoSNwWjZzX",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:25.368988Z",
+ "iopub.status.busy": "2023-01-22T12:41:25.367925Z",
+ "iopub.status.idle": "2023-01-22T12:41:25.558751Z",
+ "shell.execute_reply": "2023-01-22T12:41:25.557372Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:25.368937Z"
+ },
+ "id": "rWCoSNwWjZzX",
+ "outputId": "c7738105-161d-4c43-9028-21b64809ad04"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "# users: 86614\n",
+ "# users with at least 5 interactions: 14563\n"
+ ]
+ }
+ ],
+ "source": [
+ "users_interactions_count_df = interactions_df.groupby(['user_id', 'item_id']).size().groupby('user_id').size()\n",
+ "print('# users: %d' % len(users_interactions_count_df))\n",
+ "users_with_enough_interactions_df = users_interactions_count_df[users_interactions_count_df >= 5].reset_index()[['user_id']]\n",
+ "print('# users with at least 5 interactions: %d' % len(users_with_enough_interactions_df))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "qDCcr1_UjZzY",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:27.227318Z",
+ "iopub.status.busy": "2023-01-22T12:41:27.226717Z",
+ "iopub.status.idle": "2023-01-22T12:41:27.326827Z",
+ "shell.execute_reply": "2023-01-22T12:41:27.325761Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:27.227269Z"
+ },
+ "id": "qDCcr1_UjZzY",
+ "outputId": "cc44175d-eef0-42b9-839a-4c1a0efa5ce4"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "# of interactions: 263874\n",
+ "# of interactions from users with at least 5 interactions: 142670\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('# of interactions: %d' % len(interactions_df))\n",
+ "interactions_from_selected_users_df = interactions_df.merge(users_with_enough_interactions_df, \n",
+ " how = 'right',\n",
+ " left_on = 'user_id',\n",
+ " right_on = 'user_id')\n",
+ "print('# of interactions from users with at least 5 interactions: %d' % len(interactions_from_selected_users_df))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "bs9IdB8fjZzY",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:30.431311Z",
+ "iopub.status.busy": "2023-01-22T12:41:30.430823Z",
+ "iopub.status.idle": "2023-01-22T12:41:30.436607Z",
+ "shell.execute_reply": "2023-01-22T12:41:30.435654Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:30.431275Z"
+ },
+ "id": "bs9IdB8fjZzY"
+ },
+ "outputs": [],
+ "source": [
+ "import math"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "MTW_Y4iOjZzY",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 376
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:32.237281Z",
+ "iopub.status.busy": "2023-01-22T12:41:32.236079Z",
+ "iopub.status.idle": "2023-01-22T12:41:32.403346Z",
+ "shell.execute_reply": "2023-01-22T12:41:32.401909Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:32.237217Z"
+ },
+ "id": "MTW_Y4iOjZzY",
+ "outputId": "8027a8a8-0bf9-4c97-9b69-5281653709c9"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "# of unique user/item interactions: 142670\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " watched_pct | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 21 | \n",
+ " 849 | \n",
+ " 6.375039 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 21 | \n",
+ " 4345 | \n",
+ " 6.658211 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 21 | \n",
+ " 10283 | \n",
+ " 6.658211 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 21 | \n",
+ " 12261 | \n",
+ " 6.658211 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 21 | \n",
+ " 15997 | \n",
+ " 6.658211 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 32 | \n",
+ " 952 | \n",
+ " 6.044394 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 32 | \n",
+ " 4382 | \n",
+ " 4.954196 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 32 | \n",
+ " 4807 | \n",
+ " 6.658211 | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 32 | \n",
+ " 10436 | \n",
+ " 6.658211 | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 32 | \n",
+ " 12132 | \n",
+ " 6.658211 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " user_id item_id watched_pct\n",
+ "0 21 849 6.375039\n",
+ "1 21 4345 6.658211\n",
+ "2 21 10283 6.658211\n",
+ "3 21 12261 6.658211\n",
+ "4 21 15997 6.658211\n",
+ "5 32 952 6.044394\n",
+ "6 32 4382 4.954196\n",
+ "7 32 4807 6.658211\n",
+ "8 32 10436 6.658211\n",
+ "9 32 12132 6.658211"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def smooth_user_preference(x):\n",
+ " return math.log(1+x, 2)\n",
+ " \n",
+ "interactions_full_df = interactions_from_selected_users_df \\\n",
+ " .groupby(['user_id', 'item_id'])['watched_pct'].sum() \\\n",
+ " .apply(smooth_user_preference).reset_index()\n",
+ "print('# of unique user/item interactions: %d' % len(interactions_full_df))\n",
+ "interactions_full_df.head(10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "wNyqdsCxjZzZ",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:34.443808Z",
+ "iopub.status.busy": "2023-01-22T12:41:34.443346Z",
+ "iopub.status.idle": "2023-01-22T12:41:34.651267Z",
+ "shell.execute_reply": "2023-01-22T12:41:34.650080Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:34.443774Z"
+ },
+ "id": "wNyqdsCxjZzZ",
+ "outputId": "e2a2e169-78ef-4f8e-c099-eb56de49338e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "# interactions on Train set: 114136\n",
+ "# interactions on Test set: 28534\n"
+ ]
+ }
+ ],
+ "source": [
+ "interactions_train_df, interactions_test_df = train_test_split(interactions_full_df,\n",
+ " stratify=interactions_full_df['user_id'], \n",
+ " test_size=0.20,\n",
+ " random_state=42)\n",
+ "\n",
+ "print('# interactions on Train set: %d' % len(interactions_train_df))\n",
+ "print('# interactions on Test set: %d' % len(interactions_test_df))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "v1M9fBagjZzZ",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:38.570246Z",
+ "iopub.status.busy": "2023-01-22T12:41:38.568905Z",
+ "iopub.status.idle": "2023-01-22T12:41:38.583223Z",
+ "shell.execute_reply": "2023-01-22T12:41:38.581705Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:38.570182Z"
+ },
+ "id": "v1M9fBagjZzZ"
+ },
+ "outputs": [],
+ "source": [
+ "#Indexing by user_id to speed up the searches during evaluation\n",
+ "interactions_full_indexed_df = interactions_full_df.set_index('user_id')\n",
+ "interactions_train_indexed_df = interactions_train_df.set_index('user_id')\n",
+ "interactions_test_indexed_df = interactions_test_df.set_index('user_id')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "Ra2TntFUjZzZ",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:42.934656Z",
+ "iopub.status.busy": "2023-01-22T12:41:42.934139Z",
+ "iopub.status.idle": "2023-01-22T12:41:42.940917Z",
+ "shell.execute_reply": "2023-01-22T12:41:42.939611Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:42.934617Z"
+ },
+ "id": "Ra2TntFUjZzZ"
+ },
+ "outputs": [],
+ "source": [
+ "def get_items_interacted(person_id, interactions_df):\n",
+ " # Select the item_id values of the rows labelled person_id in interactions_df (no merge is performed).\n",
+ " interacted_items = interactions_df.loc[person_id]['item_id']\n",
+ " return set(interacted_items if type(interacted_items) == pd.Series else [interacted_items])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "xpP7YjhRjZzZ",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:53.435832Z",
+ "iopub.status.busy": "2023-01-22T12:41:53.435366Z",
+ "iopub.status.idle": "2023-01-22T12:41:53.455616Z",
+ "shell.execute_reply": "2023-01-22T12:41:53.454525Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:53.435796Z"
+ },
+ "id": "xpP7YjhRjZzZ"
+ },
+ "outputs": [],
+ "source": [
+ "#Top-N accuracy metrics consts\n",
+ "EVAL_RANDOM_SAMPLE_NON_INTERACTED_ITEMS = 100\n",
+ "\n",
+ "class ModelEvaluator:\n",
+ "\n",
+ " def get_not_interacted_items_sample(self, person_id, sample_size, seed=42):\n",
+ " interacted_items = get_items_interacted(person_id, interactions_full_indexed_df)\n",
+ " all_items = set(articles_df['item_id'])\n",
+ " non_interacted_items = all_items - interacted_items\n",
+ "\n",
+ " random.seed(seed)\n",
+ " non_interacted_items_sample = random.sample(non_interacted_items, sample_size)\n",
+ " return set(non_interacted_items_sample)\n",
+ "\n",
+ " def _verify_hit_top_n(self, item_id, recommended_items, topn): \n",
+ " try:\n",
+ " index = next(i for i, c in enumerate(recommended_items) if c == item_id)\n",
+ " except:\n",
+ " index = -1\n",
+ " hit = int(index in range(0, topn))\n",
+ " return hit, index\n",
+ "\n",
+ " def evaluate_model_for_user(self, model, person_id):\n",
+ " #Getting the items in test set\n",
+ " interacted_values_testset = interactions_test_indexed_df.loc[person_id]\n",
+ " if type(interacted_values_testset['item_id']) == pd.Series:\n",
+ " person_interacted_items_testset = set(interacted_values_testset['item_id'])\n",
+ " else:\n",
+ " person_interacted_items_testset = set([int(interacted_values_testset['item_id'])]) \n",
+ " interacted_items_count_testset = len(person_interacted_items_testset) \n",
+ "\n",
+ " #Getting a ranked recommendation list from a model for a given user\n",
+ " person_recs_df = model.recommend_items(person_id, \n",
+ " items_to_ignore=get_items_interacted(person_id, \n",
+ " interactions_train_indexed_df), \n",
+ " topn=10000000000)\n",
+ "\n",
+ " hits_at_5_count = 0\n",
+ " hits_at_10_count = 0\n",
+ " #For each item the user has interacted in test set\n",
+ " for item_id in person_interacted_items_testset:\n",
+ " #Getting a random sample (100) items the user has not interacted \n",
+ " #(to represent items that are assumed to be not relevant to the user)\n",
+ " non_interacted_items_sample = self.get_not_interacted_items_sample(person_id, \n",
+ " sample_size=EVAL_RANDOM_SAMPLE_NON_INTERACTED_ITEMS, \n",
+ " seed=item_id%(2**32))\n",
+ "\n",
+ " #Combining the current interacted item with the 100 random items\n",
+ " items_to_filter_recs = non_interacted_items_sample.union(set([item_id]))\n",
+ "\n",
+ " #Filtering only recommendations that are either the interacted item or from a random sample of 100 non-interacted items\n",
+ " valid_recs_df = person_recs_df[person_recs_df['item_id'].isin(items_to_filter_recs)] \n",
+ " valid_recs = valid_recs_df['item_id'].values\n",
+ " #Verifying if the current interacted item is among the Top-N recommended items\n",
+ " hit_at_5, index_at_5 = self._verify_hit_top_n(item_id, valid_recs, 5)\n",
+ " hits_at_5_count += hit_at_5\n",
+ " hit_at_10, index_at_10 = self._verify_hit_top_n(item_id, valid_recs, 10)\n",
+ " hits_at_10_count += hit_at_10\n",
+ "\n",
+ " #Recall is the rate of the interacted items that are ranked among the Top-N recommended items, \n",
+ " #when mixed with a set of non-relevant items\n",
+ " recall_at_5 = hits_at_5_count / float(interacted_items_count_testset)\n",
+ " recall_at_10 = hits_at_10_count / float(interacted_items_count_testset)\n",
+ "\n",
+ " person_metrics = {'hits@5_count':hits_at_5_count, \n",
+ " 'hits@10_count':hits_at_10_count, \n",
+ " 'interacted_count': interacted_items_count_testset,\n",
+ " 'recall@5': recall_at_5,\n",
+ " 'recall@10': recall_at_10}\n",
+ " return person_metrics\n",
+ "\n",
+ " def evaluate_model(self, model):\n",
+ " #print('Running evaluation for users')\n",
+ " people_metrics = []\n",
+ " for idx, person_id in enumerate(tqdm(list(interactions_test_indexed_df.index.unique().values))):\n",
+ " #if idx % 100 == 0 and idx > 0:\n",
+ " # print('%d users processed' % idx)\n",
+ " person_metrics = self.evaluate_model_for_user(model, person_id) \n",
+ " person_metrics['user_id'] = person_id\n",
+ " people_metrics.append(person_metrics)\n",
+ " print('%d users processed' % idx)\n",
+ "\n",
+ " detailed_results_df = pd.DataFrame(people_metrics) \\\n",
+ " .sort_values('interacted_count', ascending=False)\n",
+ " \n",
+ " global_recall_at_5 = detailed_results_df['hits@5_count'].sum() / float(detailed_results_df['interacted_count'].sum())\n",
+ " global_recall_at_10 = detailed_results_df['hits@10_count'].sum() / float(detailed_results_df['interacted_count'].sum())\n",
+ " \n",
+ " global_metrics = {'modelName': model.get_model_name(),\n",
+ " 'recall@5': global_recall_at_5,\n",
+ " 'recall@10': global_recall_at_10} \n",
+ " return global_metrics, detailed_results_df\n",
+ " \n",
+ "model_evaluator = ModelEvaluator() "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "bt-Ko_HMjZza",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:41:57.779034Z",
+ "iopub.status.busy": "2023-01-22T12:41:57.777417Z",
+ "iopub.status.idle": "2023-01-22T12:41:57.787389Z",
+ "shell.execute_reply": "2023-01-22T12:41:57.785909Z",
+ "shell.execute_reply.started": "2023-01-22T12:41:57.778960Z"
+ },
+ "id": "bt-Ko_HMjZza"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import display, clear_output\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from tqdm.notebook import tqdm\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torch.nn import functional as F\n",
+ "from torch.utils.data import Dataset, DataLoader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "6ySqiCo5jZza",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:42:03.271305Z",
+ "iopub.status.busy": "2023-01-22T12:42:03.270810Z",
+ "iopub.status.idle": "2023-01-22T12:42:03.278535Z",
+ "shell.execute_reply": "2023-01-22T12:42:03.277141Z",
+ "shell.execute_reply.started": "2023-01-22T12:42:03.271268Z"
+ },
+ "id": "6ySqiCo5jZza"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "# Constants\n",
+ "SEED = 42 # random seed for reproducibility\n",
+ "LR = 1e-3 # learning rate, controls the speed of the training\n",
+ "WEIGHT_DECAY = 0.01 # weight decay coefficient passed to AdamW\n",
+ "NUM_EPOCHS = 200 # num training epochs (how many times each instance will be processed)\n",
+ "GAMMA = 0.9995 # learning rate scheduler parameter\n",
+ "BATCH_SIZE = 3000 # training batch size\n",
+ "EVAL_BATCH_SIZE = 3000 # evaluation batch size.\n",
+ "DEVICE = 'cuda' #'cuda' # device to make the calculations on"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "FtzzvibljZza",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:42:05.933060Z",
+ "iopub.status.busy": "2023-01-22T12:42:05.931911Z",
+ "iopub.status.idle": "2023-01-22T12:42:05.969002Z",
+ "shell.execute_reply": "2023-01-22T12:42:05.967458Z",
+ "shell.execute_reply.started": "2023-01-22T12:42:05.933000Z"
+ },
+ "id": "FtzzvibljZza"
+ },
+ "outputs": [],
+ "source": [
+ "total_df = interactions_train_df.append(interactions_test_indexed_df.reset_index())\n",
+ "total_df['user_id'], users_keys = total_df.user_id.factorize()\n",
+ "total_df['item_id'], items_keys = total_df.item_id.factorize()\n",
+ "\n",
+ "train_encoded = total_df.iloc[:len(interactions_train_df)].values\n",
+ "test_encoded = total_df.iloc[len(interactions_train_df):].values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "crbEdHiJjZza",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:42:09.354000Z",
+ "iopub.status.busy": "2023-01-22T12:42:09.352465Z",
+ "iopub.status.idle": "2023-01-22T12:42:09.967185Z",
+ "shell.execute_reply": "2023-01-22T12:42:09.965725Z",
+ "shell.execute_reply.started": "2023-01-22T12:42:09.353932Z"
+ },
+ "id": "crbEdHiJjZza"
+ },
+ "outputs": [],
+ "source": [
+ "from scipy.sparse import csr_matrix\n",
+ "shape = [int(total_df['user_id'].max()+1), int(total_df['item_id'].max()+1)]\n",
+ "X_train = csr_matrix((train_encoded[:, 2], (train_encoded[:, 0], train_encoded[:, 1])), shape=shape).toarray()\n",
+ "X_test = csr_matrix((test_encoded[:, 2], (test_encoded[:, 0], test_encoded[:, 1])), shape=shape).toarray()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "sFeJZsDJjZzb",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:42:12.745785Z",
+ "iopub.status.busy": "2023-01-22T12:42:12.745283Z",
+ "iopub.status.idle": "2023-01-22T12:42:12.754320Z",
+ "shell.execute_reply": "2023-01-22T12:42:12.752855Z",
+ "shell.execute_reply.started": "2023-01-22T12:42:12.745745Z"
+ },
+ "id": "sFeJZsDJjZzb"
+ },
+ "outputs": [],
+ "source": [
+ "# Define the Dataset object, which must return an element (here: one row of X)\n",
+ "# for a given idx. This class must also report its length via __len__\n",
+ "class UserOrientedDataset(Dataset):\n",
+ " def __init__(self, X):\n",
+ " super().__init__() # to initialize the parent class\n",
+ " self.X = X.astype(np.float32)\n",
+ " self.len = len(X)\n",
+ "\n",
+ " def __len__(self): # We use __func__ for implementing in-built python functions\n",
+ " return self.len\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " return self.X[index]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "AoCCUSpUjZzb",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:42:16.254953Z",
+ "iopub.status.busy": "2023-01-22T12:42:16.254416Z",
+ "iopub.status.idle": "2023-01-22T12:42:17.434704Z",
+ "shell.execute_reply": "2023-01-22T12:42:17.433103Z",
+ "shell.execute_reply.started": "2023-01-22T12:42:16.254903Z"
+ },
+ "id": "AoCCUSpUjZzb"
+ },
+ "outputs": [],
+ "source": [
+ "# Initialize DataLoaders - objects, which sample instances from DataObject-s\n",
+ "train_dl = DataLoader(\n",
+ " UserOrientedDataset(X_train),\n",
+ " batch_size = BATCH_SIZE,\n",
+ " shuffle = True\n",
+ ")\n",
+ "\n",
+ "test_dl = DataLoader(\n",
+ " UserOrientedDataset(X_test),\n",
+ " batch_size = EVAL_BATCH_SIZE,\n",
+ " shuffle = False\n",
+ ")\n",
+ "\n",
+ "dls = {'train': train_dl, 'test': test_dl}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "b94CXGocjZzb",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:53:12.965059Z",
+ "iopub.status.busy": "2023-01-22T12:53:12.964527Z",
+ "iopub.status.idle": "2023-01-22T12:53:12.975037Z",
+ "shell.execute_reply": "2023-01-22T12:53:12.973690Z",
+ "shell.execute_reply.started": "2023-01-22T12:53:12.965016Z"
+ },
+ "id": "b94CXGocjZzb"
+ },
+ "outputs": [],
+ "source": [
+ "class Model(nn.Module):\n",
+ " def __init__(self, in_and_out_features = 8287):\n",
+ " super().__init__()\n",
+ " self.in_and_out_features = in_and_out_features\n",
+ " self.hidden_size = 500\n",
+ "\n",
+ " self.sequential = nn.Sequential( \n",
+ " nn.Linear(in_and_out_features, self.hidden_size), \n",
+ " nn.ReLU(), \n",
+ " nn.Linear(self.hidden_size, in_and_out_features) # Another Linear transformation\n",
+ " )\n",
+ "\n",
+ " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n",
+ " x = self.sequential(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "aY_vqVZLjZzb",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:54:25.315144Z",
+ "iopub.status.busy": "2023-01-22T12:54:25.314623Z",
+ "iopub.status.idle": "2023-01-22T12:54:26.136714Z",
+ "shell.execute_reply": "2023-01-22T12:54:26.135715Z",
+ "shell.execute_reply.started": "2023-01-22T12:54:25.315101Z"
+ },
+ "id": "aY_vqVZLjZzb"
+ },
+ "outputs": [],
+ "source": [
+ "torch.manual_seed(SEED) # Fix random seed to have reproducible weights of model layers\n",
+ "\n",
+ "model = Model()\n",
+ "model.to(DEVICE)\n",
+ "\n",
+ "# Initialize GD method, which will update the weights of the model\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+ "# Initialize learning rate scheduler, which will decrease LR according to some rule\n",
+ "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)\n",
+ "\n",
+ "def rmse_for_sparse(x_pred, x_true):\n",
+ " mask = (x_true > 0)\n",
+ " sq_diff = (x_pred * mask - x_true) ** 2\n",
+ " mse = sq_diff.sum() / mask.sum()\n",
+ " return mse ** (1/2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "LdlKerxfjZzb",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 419
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:54:33.544338Z",
+ "iopub.status.busy": "2023-01-22T12:54:33.543734Z"
+ },
+ "id": "LdlKerxfjZzb",
+ "outputId": "0bc103bb-151d-449f-b7b2-f670b8970d92"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Epoch | \n",
+ " Train RMSE | \n",
+ " Test RMSE | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0 | \n",
+ " 2.315015 | \n",
+ " 2.295504 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 1 | \n",
+ " 2.191636 | \n",
+ " 2.224912 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 2 | \n",
+ " 1.955497 | \n",
+ " 2.108439 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 3 | \n",
+ " 1.836119 | \n",
+ " 2.027701 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 4 | \n",
+ " 1.736783 | \n",
+ " 2.026640 | \n",
+ "
\n",
+ " \n",
+ " | ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " | 195 | \n",
+ " 195 | \n",
+ " 0.288658 | \n",
+ " 1.330020 | \n",
+ "
\n",
+ " \n",
+ " | 196 | \n",
+ " 196 | \n",
+ " 0.277917 | \n",
+ " 1.331115 | \n",
+ "
\n",
+ " \n",
+ " | 197 | \n",
+ " 197 | \n",
+ " 0.307082 | \n",
+ " 1.330125 | \n",
+ "
\n",
+ " \n",
+ " | 198 | \n",
+ " 198 | \n",
+ " 0.302980 | \n",
+ " 1.331673 | \n",
+ "
\n",
+ " \n",
+ " | 199 | \n",
+ " 199 | \n",
+ " 0.307337 | \n",
+ " 1.329725 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
200 rows × 3 columns
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n",
+ "
\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ " Epoch Train RMSE Test RMSE\n",
+ "0 0 2.315015 2.295504\n",
+ "1 1 2.191636 2.224912\n",
+ "2 2 1.955497 2.108439\n",
+ "3 3 1.836119 2.027701\n",
+ "4 4 1.736783 2.026640\n",
+ ".. ... ... ...\n",
+ "195 195 0.288658 1.330020\n",
+ "196 196 0.277917 1.331115\n",
+ "197 197 0.307082 1.330125\n",
+ "198 198 0.302980 1.331673\n",
+ "199 199 0.307337 1.329725\n",
+ "\n",
+ "[200 rows x 3 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Training loop\n",
+ "metrics_dict = {\n",
+ " \"Epoch\": [],\n",
+ " \"Train RMSE\": [],\n",
+ " \"Test RMSE\": [],\n",
+ "}\n",
+ "\n",
+ "# Train loop\n",
+ "for epoch in range(NUM_EPOCHS):\n",
+ " metrics_dict[\"Epoch\"].append(epoch)\n",
+ " for stage in ['train', 'test']:\n",
+ " with torch.set_grad_enabled(stage == 'train'): # Whether to start building a graph for a backward pass\n",
+ " if stage == 'train':\n",
+ " model.train() # Enable some \"special\" layers (will speak about later)\n",
+ " else:\n",
+ " model.eval() # Disable some \"special\" layers (will speak about later)\n",
+ "\n",
+ " loss_at_stage = 0 \n",
+ " for batch in dls[stage]:\n",
+ " batch = batch.to(DEVICE)\n",
+ " x_pred = model(batch) # forward pass: model(x_batch) -> calls forward()\n",
+ " loss = rmse_for_sparse(x_pred, batch) # ¡Important! y_pred is always the first arg\n",
+ " if stage == \"train\":\n",
+ " loss.backward() # Calculate the gradients of all the parameters wrt loss\n",
+ " optimizer.step() # Update the parameters\n",
+ " scheduler.step()\n",
+ " optimizer.zero_grad() # Zero the saved gradient\n",
+ " loss_at_stage += loss.item() * len(batch)\n",
+ " rmse_at_stage = (loss_at_stage / len(dls[stage].dataset)) ** (1/2)\n",
+ " metrics_dict[f\"{stage.title()} RMSE\"].append(rmse_at_stage)\n",
+ " \n",
+ " if (epoch == NUM_EPOCHS - 1) or epoch % 10 == 9:\n",
+ " clear_output(wait=True)\n",
+ " display(pd.DataFrame(metrics_dict))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "ZXCPjyMajZzb",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ZXCPjyMajZzb",
+ "outputId": "a0448c4f-5e53-409b-c277-fc704e617202"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 0.3084, 2.5601, 1.0144, ..., -0.0948, -0.1467, 0.3106],\n",
+ " [ 0.1575, 0.8934, 0.1315, ..., -0.1049, 0.0096, 0.0350],\n",
+ " [ 0.6704, 1.5142, 0.6962, ..., -0.2259, -0.0353, 0.0676],\n",
+ " ...,\n",
+ " [ 0.3153, 1.1243, 0.1393, ..., -0.1222, -0.1398, 0.0617],\n",
+ " [ 0.3214, 1.9313, 0.3253, ..., -0.1548, -0.0918, -0.0392],\n",
+ " [ 0.3434, 0.9318, -0.0341, ..., -0.1714, -0.0446, 0.1267]],\n",
+ " device='cuda:0')"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " X_pred = model(torch.Tensor(X_test).to(DEVICE))\n",
+ "X_pred"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "bkSfO9fgjZzc",
+ "metadata": {
+ "id": "bkSfO9fgjZzc"
+ },
+ "outputs": [],
+ "source": [
+ "class AERecommender:\n",
+ " \n",
+ " MODEL_NAME = 'Autoencoder'\n",
+ " \n",
+ " def __init__(self, X_preds, X_train_and_val, X_test):\n",
+ "\n",
+ " self.X_preds = X_preds.cpu().detach().numpy()\n",
+ " self.X_train_and_val = X_train_and_val\n",
+ " self.X_test = X_test\n",
+ " \n",
+ " def get_model_name(self):\n",
+ " return self.MODEL_NAME\n",
+ " \n",
+ " def recommend_items(self, user_id, items_to_select_idx, topn=10, verbose=False):\n",
+ " user_preds = self.X_preds[user_id][items_to_select_idx]\n",
+ " items_idx = items_to_select_idx[np.argsort(-user_preds)[:topn]]\n",
+ "\n",
+ " # Return the topn candidate item indices, ordered by descending predicted score.\n",
+ " return items_idx\n",
+ "\n",
+ " def evaluate(self, size=100):\n",
+ "\n",
+ " X_total = self.X_train_and_val + self.X_test\n",
+ "\n",
+ " true_5 = []\n",
+ " true_10 = []\n",
+ "\n",
+ " for user_id in range(len(X_test)):\n",
+ " non_zero = np.argwhere(self.X_test[user_id] > 0).ravel()\n",
+ " all_nonzero = np.argwhere(X_total[user_id] > 0).ravel()\n",
+ " select_from = np.setdiff1d(np.arange(X_total.shape[1]), all_nonzero)\n",
+ "\n",
+ " for non_zero_idx in non_zero:\n",
+ " random_non_interacted_100_items = np.random.choice(select_from, size=20, replace=False)\n",
+ " preds = self.recommend_items(user_id, np.append(random_non_interacted_100_items, non_zero_idx), topn=10)\n",
+ " true_5.append(non_zero_idx in preds[:5])\n",
+ " true_10.append(non_zero_idx in preds)\n",
+ "\n",
+ " return {\"recall@5\": np.mean(true_5), \"recall@10\": np.mean(true_10)}\n",
+ " \n",
+ "ae_recommender_model = AERecommender(X_pred, X_train, X_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "yRBbD9xmjZzc",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "yRBbD9xmjZzc",
+ "outputId": "d407d2b7-ee44-4299-9b29-046f41deb396"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'recall@5': 0.08641891035330142, 'recall@10': 0.25274264483602643}"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ae_global_metrics = ae_recommender_model.evaluate()\n",
+ "ae_global_metrics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ydc-4MJn-KFM",
+ "metadata": {
+ "id": "ydc-4MJn-KFM"
+ },
+ "source": [
+ "Проведем эксперименты с моделями и гиперпараметрами"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "GZfxQH7Z-hMK",
+ "metadata": {
+ "id": "GZfxQH7Z-hMK"
+ },
+ "outputs": [],
+ "source": [
+ "def train_model():\n",
+ " torch.manual_seed(SEED) # Fix random seed to have reproducible weights of model layers\n",
+ "\n",
+ " model = Model()\n",
+ " model.to(DEVICE)\n",
+ "\n",
+ " # Initialize GD method, which will update the weights of the model\n",
+ " optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+ " # Initialize learning rate scheduler, which will decrease LR according to some rule\n",
+ " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)\n",
+ "\n",
+ "\n",
+ " # Training loop\n",
+ " metrics_dict = {\n",
+ " \"Epoch\": [],\n",
+ " \"Train RMSE\": [],\n",
+ " \"Test RMSE\": [],\n",
+ " }\n",
+ "\n",
+ " # Train loop\n",
+ " for epoch in range(NUM_EPOCHS):\n",
+ " metrics_dict[\"Epoch\"].append(epoch)\n",
+ " for stage in ['train', 'test']:\n",
+ " with torch.set_grad_enabled(stage == 'train'): # Whether to start building a graph for a backward pass\n",
+ " if stage == 'train':\n",
+ " model.train() # Enable some \"special\" layers (will speak about later)\n",
+ " else:\n",
+ " model.eval() # Disable some \"special\" layers (will speak about later)\n",
+ "\n",
+ " loss_at_stage = 0 \n",
+ " for batch in dls[stage]:\n",
+ " batch = batch.to(DEVICE)\n",
+ " x_pred = model(batch) # forward pass: model(x_batch) -> calls forward()\n",
+ " loss = rmse_for_sparse(x_pred, batch) # ¡Important! y_pred is always the first arg\n",
+ " if stage == \"train\":\n",
+ " loss.backward() # Calculate the gradients of all the parameters wrt loss\n",
+ " optimizer.step() # Update the parameters\n",
+ " scheduler.step()\n",
+ " optimizer.zero_grad() # Zero the saved gradient\n",
+ " loss_at_stage += loss.item() * len(batch)\n",
+ " rmse_at_stage = (loss_at_stage / len(dls[stage].dataset)) ** (1/2)\n",
+ " metrics_dict[f\"{stage.title()} RMSE\"].append(rmse_at_stage)\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " X_pred = model(torch.Tensor(X_test).to(DEVICE))\n",
+ "\n",
+ " ae_recommender_model = AERecommender(X_pred, X_train, X_train)\n",
+ "\n",
+ " ae_global_metrics = ae_recommender_model.evaluate()\n",
+ "\n",
+ " metrics_dict[\"recall@5\"] = ae_global_metrics[\"recall@5\"]\n",
+ " metrics_dict[\"recall@10\"] = ae_global_metrics[\"recall@10\"]\n",
+ "\n",
+ "\n",
+ " return metrics_dict"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "iYS06bYkA5uD",
+ "metadata": {
+ "id": "iYS06bYkA5uD"
+ },
+ "source": [
+ "C изначальной архитектурой"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "s69HDH9P-PZl",
+ "metadata": {
+ "id": "s69HDH9P-PZl"
+ },
+ "outputs": [],
+ "source": [
+ "class Model(nn.Module):\n",
+ " def __init__(self, in_and_out_features = 8287):\n",
+ " super().__init__()\n",
+ " self.in_and_out_features = in_and_out_features\n",
+ " self.hidden_size = 500\n",
+ "\n",
+ " self.sequential = nn.Sequential( \n",
+ " nn.Linear(in_and_out_features, self.hidden_size), \n",
+ " nn.ReLU(), \n",
+ " nn.Linear(self.hidden_size, in_and_out_features) # Another Linear transformation\n",
+ " )\n",
+ "\n",
+ " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n",
+ " x = self.sequential(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "TytUsH6vA9Wo",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "TytUsH6vA9Wo",
+ "outputId": "464573c4-6c3f-4b04-ac32-3ea09fb84f08"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "lr:0.001 ne:0.001 bs:3000 ....\n",
+ "lr:0.001 ne:0.001 bs:4500 ....\n",
+ "lr:0.001 ne:0.001 bs:3000 ....\n",
+ "lr:0.001 ne:0.001 bs:4500 ....\n",
+ "lr:0.0003 ne:0.0003 bs:3000 ....\n",
+ "lr:0.0003 ne:0.0003 bs:4500 ....\n",
+ "lr:0.0003 ne:0.0003 bs:3000 ....\n",
+ "lr:0.0003 ne:0.0003 bs:4500 ....\n"
+ ]
+ }
+ ],
+ "source": [
+ "first_arch_metrics = {}\n",
+ "\n",
+ "for lr in [0.001, 0.0003]:\n",
+ " for ne in [100, 200]:\n",
+ " for bs in [3000, 4500]:\n",
+ " \n",
+ " print(f\"lr:{lr} ne:{ne} bs:{bs} ....\" )\n",
+ "\n",
+ " LR = lr\n",
+ " NUM_EPOCHS = ne\n",
+ " BATCH_SIZE = bs\n",
+ "\n",
+ " first_arch_metrics[f\"lr:{lr} ne:{ne} bs:{bs}\"] = train_model()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "HnEm5GLZDAuC",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HnEm5GLZDAuC",
+ "outputId": "d652f3d6-c81a-4e35-e070-7350ce956120"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "lr:0.001 ne:100 bs:3000 0.0856485318926931 0.24734999561176826\n",
+ "lr:0.001 ne:100 bs:4500 0.08339590626737009 0.2456629642992969\n",
+ "lr:0.001 ne:200 bs:3000 0.08698450466615308 0.2526061220708553\n",
+ "lr:0.001 ne:200 bs:4500 0.0867699688923128 0.2529181741055321\n",
+ "lr:0.0003 ne:100 bs:3000 0.08751109247467015 0.25363979443572215\n",
+ "lr:0.0003 ne:100 bs:4500 0.08879830711771187 0.25470272167883995\n",
+ "lr:0.0003 ne:200 bs:3000 0.08135781641588735 0.23005061093937415\n",
+ "lr:0.0003 ne:200 bs:4500 0.08193316235482266 0.23188391664310024\n"
+ ]
+ }
+ ],
+ "source": [
+ "for i in first_arch_metrics.keys():\n",
+ " print(i, first_arch_metrics[i]['recall@5'], first_arch_metrics[i]['recall@10'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "ZaDleIoiPyCJ",
+ "metadata": {
+ "id": "ZaDleIoiPyCJ"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "DOeEZG5_A9p4",
+ "metadata": {
+ "id": "DOeEZG5_A9p4"
+ },
+ "source": [
+ "Усложним архитектуру"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "rTOhYgiX-fEr",
+ "metadata": {
+ "id": "rTOhYgiX-fEr"
+ },
+ "outputs": [],
+ "source": [
+ "class Model(nn.Module):\n",
+ " def __init__(self, in_and_out_features = 8287):\n",
+ " super().__init__()\n",
+ " self.in_and_out_features = in_and_out_features\n",
+ " self.hidden_size = 512\n",
+ "\n",
+ " self.sequential = nn.Sequential( \n",
+ " nn.Linear(in_and_out_features, 4096), \n",
+ " nn.ReLU(), \n",
+ "\n",
+ " nn.Linear(4096, self.hidden_size), \n",
+ " nn.ReLU(),\n",
+ "\n",
+ " nn.Linear(self.hidden_size, 4096), \n",
+ " nn.ReLU(), \n",
+ "\n",
+ " nn.Linear(4096, in_and_out_features) # Another Linear transformation\n",
+ " )\n",
+ "\n",
+ " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n",
+ " x = self.sequential(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "TPRykgiN-fNe",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "TPRykgiN-fNe",
+ "outputId": "a32247b1-3a86-479d-aa0f-df75119e18e2"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "lr:0.001 ne:100 bs:3000 ....\n",
+ "lr:0.001 ne:100 bs:4500 ....\n",
+ "lr:0.001 ne:200 bs:3000 ....\n",
+ "lr:0.001 ne:200 bs:4500 ....\n",
+ "lr:0.0003 ne:100 bs:3000 ....\n",
+ "lr:0.0003 ne:100 bs:4500 ....\n",
+ "lr:0.0003 ne:200 bs:3000 ....\n",
+ "lr:0.0003 ne:200 bs:4500 ....\n"
+ ]
+ }
+ ],
+ "source": [
+ "second_arch_metrics = {}\n",
+ "\n",
+ "for lr in [0.001, 0.0003]:\n",
+ " for ne in [100, 200]:\n",
+ " for bs in [3000, 4500]:\n",
+ " \n",
+ " print(f\"lr:{lr} ne:{ne} bs:{bs} ....\" )\n",
+ "\n",
+ " LR = lr\n",
+ " NUM_EPOCHS = ne\n",
+ " BATCH_SIZE = bs\n",
+ "\n",
+ " second_arch_metrics[f\"lr:{lr} ne:{ne} bs:{bs}\"] = train_model()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "wMGBNnslD4ax",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "wMGBNnslD4ax",
+ "outputId": "76b97e64-342e-4ef5-b71e-f6a88c7daf36"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "lr:0.001 ne:100 bs:3000 0.14852701688006476 0.363608881781037\n",
+ "lr:0.001 ne:100 bs:4500 0.14894633680166167 0.3632090651116074\n",
+ "lr:0.001 ne:200 bs:3000 0.15524588725169922 0.35925965654772934\n",
+ "lr:0.001 ne:200 bs:4500 0.1548265673301023 0.35853803621753927\n",
+ "lr:0.0003 ne:100 bs:3000 0.15456327342584375 0.37245360663890703\n",
+ "lr:0.0003 ne:100 bs:4500 0.15338332666972218 0.3731167172125952\n",
+ "lr:0.0003 ne:200 bs:3000 0.15394892098257384 0.3672852448145728\n",
+ "lr:0.0003 ne:200 bs:4500 0.1538026465913191 0.36697319277989604\n"
+ ]
+ }
+ ],
+ "source": [
+ "for i in second_arch_metrics.keys():\n",
+ " print(i, second_arch_metrics[i]['recall@5'], second_arch_metrics[i]['recall@10'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "k6zRyd-uD0Fy",
+ "metadata": {
+ "id": "k6zRyd-uD0Fy"
+ },
+ "source": [
+ "Добавим еще слоев: "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "08GqJ7iu-fPo",
+ "metadata": {
+ "id": "08GqJ7iu-fPo"
+ },
+ "outputs": [],
+ "source": [
+ "class Model(nn.Module):\n",
+ " def __init__(self, in_and_out_features = 8287):\n",
+ " super().__init__()\n",
+ " self.in_and_out_features = in_and_out_features\n",
+ " self.hidden_size = 512\n",
+ "\n",
+ " self.sequential = nn.Sequential( \n",
+ " nn.Linear(in_and_out_features, 6000), \n",
+ " nn.ReLU(), \n",
+ "\n",
+ " nn.Linear(6000, 3000), \n",
+ " nn.ReLU(),\n",
+ "\n",
+ " nn.Linear(3000, 1024), \n",
+ " nn.ReLU(),\n",
+ "\n",
+ " nn.Linear(1024, self.hidden_size), \n",
+ " nn.ReLU(),\n",
+ "\n",
+ " nn.Linear(self.hidden_size, 1024), \n",
+ " nn.ReLU(),\n",
+ "\n",
+ " nn.Linear(1024, 3000), \n",
+ " nn.ReLU(),\n",
+ "\n",
+ " nn.Linear(3000, 6000), \n",
+ " nn.ReLU(), \n",
+ "\n",
+ " nn.Linear(6000, in_and_out_features) # Another Linear transformation\n",
+ " )\n",
+ "\n",
+ " def forward(self, x): # In the forward function, you define how your model runs, from input to output \n",
+ " x = self.sequential(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "AV4bbBpd-fSg",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "AV4bbBpd-fSg",
+ "outputId": "2a985b96-d452-490b-d73c-419e0341b0d8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "lr:0.0003 ne:100 bs:3000 ....\n",
+ "lr:0.0003 ne:100 bs:4500 ....\n",
+ "lr:0.0003 ne:200 bs:3000 ....\n",
+ "lr:0.0003 ne:200 bs:4500 ....\n"
+ ]
+ }
+ ],
+ "source": [
+ "third_arch_metrics = {}\n",
+ "\n",
+ "for lr in [0.0003]:\n",
+ " for ne in [100, 200]:\n",
+ " for bs in [3000, 4500]:\n",
+ " \n",
+ " print(f\"lr:{lr} ne:{ne} bs:{bs} ....\" )\n",
+ "\n",
+ " LR = lr\n",
+ " NUM_EPOCHS = ne\n",
+ " BATCH_SIZE = bs\n",
+ "\n",
+ " third_arch_metrics[f\"lr:{lr} ne:{ne} bs:{bs}\"] = train_model()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "v1gCEb8aFQc-",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "v1gCEb8aFQc-",
+ "outputId": "f7084cf6-b161-4951-cc6d-1abe32766205"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "lr:0.0003 ne:100 bs:3000 0.24635532975123603 0.6131237383833754\n",
+ "lr:0.0003 ne:100 bs:4500 0.2430007703784606 0.6135430583049724\n",
+ "lr:0.0003 ne:200 bs:3000 0.2589251757730602 0.6017533423698401\n",
+ "lr:0.0003 ne:200 bs:4500 0.2589739339034784 0.6040157196212469\n"
+ ]
+ }
+ ],
+ "source": [
+ "for i in third_arch_metrics.keys():\n",
+ " print(i, third_arch_metrics[i]['recall@5'], third_arch_metrics[i]['recall@10'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "k0qGe8sVaZq4",
+ "metadata": {
+ "id": "k0qGe8sVaZq4"
+ },
+ "source": [
+ "Модель обучена. Лучшей моделью является модель последней архитектуры со следующими подобранными гиперпараметрами:\n",
+ "\n",
+ "* LR: 0.0003\n",
+ "* NUM_EPOCHS: 200\n",
+ "* BATCH_SIZE: 4500\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "sd7WVXXYo7H1",
+ "metadata": {
+ "id": "sd7WVXXYo7H1"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.15"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/hw_5_dssm.ipynb b/hw_5_dssm.ipynb
new file mode 100644
index 00000000..258bae7e
--- /dev/null
+++ b/hw_5_dssm.ipynb
@@ -0,0 +1,4034 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:22.841107Z",
+ "iopub.status.busy": "2023-01-22T16:23:22.840365Z",
+ "iopub.status.idle": "2023-01-22T16:23:22.850076Z",
+ "shell.execute_reply": "2023-01-22T16:23:22.848844Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:22.841044Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import ast\n",
+ "import json\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import pandas as pd\n",
+ "import pickle\n",
+ "import tensorflow as tf\n",
+ "import tensorflow.keras.backend as K\n",
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "from collections import Counter\n",
+ "from random import randint, random\n",
+ "from scipy.sparse import coo_matrix, hstack\n",
+ "from sklearn.metrics.pairwise import euclidean_distances, cosine_distances, cosine_similarity\n",
+ "from sklearn.metrics.pairwise import euclidean_distances as ED\n",
+ "from tensorflow import keras\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:39:51.661446Z",
+ "start_time": "2021-10-28T18:39:51.563879Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:22.852847Z",
+ "iopub.status.busy": "2023-01-22T16:23:22.851743Z",
+ "iopub.status.idle": "2023-01-22T16:23:29.088896Z",
+ "shell.execute_reply": "2023-01-22T16:23:29.087873Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:22.852800Z"
+ },
+ "id": "25508632"
+ },
+ "outputs": [],
+ "source": [
+ "interactions_df = pd.read_csv('interactions_processed_kion.csv')\n",
+ "users_df = pd.read_csv('users_processed_kion.csv')\n",
+ "items_df = pd.read_csv('items_processed_kion.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:35.447336Z",
+ "start_time": "2021-10-28T18:40:35.434541Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:29.097384Z",
+ "iopub.status.busy": "2023-01-22T16:23:29.094877Z",
+ "iopub.status.idle": "2023-01-22T16:23:29.123826Z",
+ "shell.execute_reply": "2023-01-22T16:23:29.123005Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:29.097341Z"
+ },
+ "id": "f5eacb31",
+ "outputId": "37b5c35b-4f4b-48ea-9012-a6ce7eed31c7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " age | \n",
+ " income | \n",
+ " sex | \n",
+ " kids_flg | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 973171 | \n",
+ " age_25_34 | \n",
+ " income_60_90 | \n",
+ " M | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 962099 | \n",
+ " age_18_24 | \n",
+ " income_20_40 | \n",
+ " M | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1047345 | \n",
+ " age_45_54 | \n",
+ " income_40_60 | \n",
+ " F | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 721985 | \n",
+ " age_45_54 | \n",
+ " income_20_40 | \n",
+ " F | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 704055 | \n",
+ " age_35_44 | \n",
+ " income_60_90 | \n",
+ " F | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id age income sex kids_flg\n",
+ "0 973171 age_25_34 income_60_90 M True\n",
+ "1 962099 age_18_24 income_20_40 M False\n",
+ "2 1047345 age_45_54 income_40_60 F False\n",
+ "3 721985 age_45_54 income_20_40 F False\n",
+ "4 704055 age_35_44 income_60_90 F False"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "users_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:36.103997Z",
+ "start_time": "2021-10-28T18:40:36.094699Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:29.130158Z",
+ "iopub.status.busy": "2023-01-22T16:23:29.127963Z",
+ "iopub.status.idle": "2023-01-22T16:23:29.145149Z",
+ "shell.execute_reply": "2023-01-22T16:23:29.144033Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:29.130122Z"
+ },
+ "id": "61669d0d"
+ },
+ "outputs": [],
+ "source": [
+ "items_df = items_df.rename(columns = {'id' : 'item_id'})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:36.378293Z",
+ "start_time": "2021-10-28T18:40:36.370946Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:29.146754Z",
+ "iopub.status.busy": "2023-01-22T16:23:29.146394Z",
+ "iopub.status.idle": "2023-01-22T16:23:29.166993Z",
+ "shell.execute_reply": "2023-01-22T16:23:29.165796Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:29.146717Z"
+ },
+ "id": "25f4462e",
+ "outputId": "5cc6c801-f866-4b52-aada-f5226a5ebc21"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " content_type | \n",
+ " title | \n",
+ " title_orig | \n",
+ " genres | \n",
+ " countries | \n",
+ " for_kids | \n",
+ " age_rating | \n",
+ " studios | \n",
+ " directors | \n",
+ " actors | \n",
+ " description | \n",
+ " keywords | \n",
+ " release_year_cat | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 10711 | \n",
+ " film | \n",
+ " поговори с ней | \n",
+ " Hable con ella | \n",
+ " драмы, зарубежные, детективы, мелодрамы | \n",
+ " испания | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " педро альмодовар | \n",
+ " Адольфо Фернандес, Ана Фернандес, Дарио Гранди... | \n",
+ " Мелодрама легендарного Педро Альмодовара «Пого... | \n",
+ " Поговори, ней, 2002, Испания, друзья, любовь, ... | \n",
+ " 2000-2010 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2508 | \n",
+ " film | \n",
+ " голые перцы | \n",
+ " Search Party | \n",
+ " зарубежные, приключения, комедии | \n",
+ " сша | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " скот армстронг | \n",
+ " Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... | \n",
+ " Уморительная современная комедия на популярную... | \n",
+ " Голые, перцы, 2014, США, друзья, свадьбы, прео... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 10716 | \n",
+ " film | \n",
+ " тактическая сила | \n",
+ " Tactical Force | \n",
+ " криминал, зарубежные, триллеры, боевики, комедии | \n",
+ " канада | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " адам п. калтраро | \n",
+ " Адриан Холмс, Даррен Шалави, Джерри Вассерман,... | \n",
+ " Профессиональный рестлер Стив Остин («Все или ... | \n",
+ " Тактическая, сила, 2011, Канада, бандиты, ганг... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 7868 | \n",
+ " film | \n",
+ " 45 лет | \n",
+ " 45 Years | \n",
+ " драмы, зарубежные, мелодрамы | \n",
+ " великобритания | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " эндрю хэй | \n",
+ " Александра Риддлстон-Барретт, Джеральдин Джейм... | \n",
+ " Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... | \n",
+ " 45, лет, 2015, Великобритания, брак, жизнь, лю... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 16268 | \n",
+ " film | \n",
+ " все решает мгновение | \n",
+ " NaN | \n",
+ " драмы, спорт, советские, мелодрамы | \n",
+ " ссср | \n",
+ " False | \n",
+ " 12.0 | \n",
+ " ленфильм | \n",
+ " виктор садовский | \n",
+ " Александр Абдулов, Александр Демьяненко, Алекс... | \n",
+ " Расчетливая чаровница из советского кинохита «... | \n",
+ " Все, решает, мгновение, 1978, СССР, сильные, ж... | \n",
+ " 1970-1980 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id content_type title title_orig \\\n",
+ "0 10711 film поговори с ней Hable con ella \n",
+ "1 2508 film голые перцы Search Party \n",
+ "2 10716 film тактическая сила Tactical Force \n",
+ "3 7868 film 45 лет 45 Years \n",
+ "4 16268 film все решает мгновение NaN \n",
+ "\n",
+ " genres countries for_kids \\\n",
+ "0 драмы, зарубежные, детективы, мелодрамы испания False \n",
+ "1 зарубежные, приключения, комедии сша False \n",
+ "2 криминал, зарубежные, триллеры, боевики, комедии канада False \n",
+ "3 драмы, зарубежные, мелодрамы великобритания False \n",
+ "4 драмы, спорт, советские, мелодрамы ссср False \n",
+ "\n",
+ " age_rating studios directors \\\n",
+ "0 16.0 unknown педро альмодовар \n",
+ "1 16.0 unknown скот армстронг \n",
+ "2 16.0 unknown адам п. калтраро \n",
+ "3 16.0 unknown эндрю хэй \n",
+ "4 12.0 ленфильм виктор садовский \n",
+ "\n",
+ " actors \\\n",
+ "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n",
+ "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n",
+ "2 Адриан Холмс, Даррен Шалави, Джерри Вассерман,... \n",
+ "3 Александра Риддлстон-Барретт, Джеральдин Джейм... \n",
+ "4 Александр Абдулов, Александр Демьяненко, Алекс... \n",
+ "\n",
+ " description \\\n",
+ "0 Мелодрама легендарного Педро Альмодовара «Пого... \n",
+ "1 Уморительная современная комедия на популярную... \n",
+ "2 Профессиональный рестлер Стив Остин («Все или ... \n",
+ "3 Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... \n",
+ "4 Расчетливая чаровница из советского кинохита «... \n",
+ "\n",
+ " keywords release_year_cat \n",
+ "0 Поговори, ней, 2002, Испания, друзья, любовь, ... 2000-2010 \n",
+ "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 2010-2020 \n",
+ "2 Тактическая, сила, 2011, Канада, бандиты, ганг... 2010-2020 \n",
+ "3 45, лет, 2015, Великобритания, брак, жизнь, лю... 2010-2020 \n",
+ "4 Все, решает, мгновение, 1978, СССР, сильные, ж... 1970-1980 "
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:36.607688Z",
+ "start_time": "2021-10-28T18:40:36.597640Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:29.169473Z",
+ "iopub.status.busy": "2023-01-22T16:23:29.168713Z",
+ "iopub.status.idle": "2023-01-22T16:23:29.183035Z",
+ "shell.execute_reply": "2023-01-22T16:23:29.181327Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:29.169432Z"
+ },
+ "id": "b41964d3",
+ "outputId": "b4c8f3d5-e7af-4e29-d2e8-0defb6993b35"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " last_watch_dt | \n",
+ " total_dur | \n",
+ " watched_pct | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 176549 | \n",
+ " 9506 | \n",
+ " 2021-05-11 | \n",
+ " 4250 | \n",
+ " 72 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 699317 | \n",
+ " 1659 | \n",
+ " 2021-05-29 | \n",
+ " 8317 | \n",
+ " 100 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 656683 | \n",
+ " 7107 | \n",
+ " 2021-05-09 | \n",
+ " 10 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 864613 | \n",
+ " 7638 | \n",
+ " 2021-07-05 | \n",
+ " 14483 | \n",
+ " 100 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 964868 | \n",
+ " 9506 | \n",
+ " 2021-04-30 | \n",
+ " 6725 | \n",
+ " 100 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id last_watch_dt total_dur watched_pct\n",
+ "0 176549 9506 2021-05-11 4250 72\n",
+ "1 699317 1659 2021-05-29 8317 100\n",
+ "2 656683 7107 2021-05-09 10 0\n",
+ "3 864613 7638 2021-07-05 14483 100\n",
+ "4 964868 9506 2021-04-30 6725 100"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cd252422"
+ },
+ "source": [
+ "## Готовим фичи пользователей"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pBdccMPAr7KR"
+ },
+ "source": [
+    "Посмотрим, какие фичи в датасете пользователей являются категориальными, и закодируем их с помощью one-hot encoding."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:37.156260Z",
+ "start_time": "2021-10-28T18:40:37.138422Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:29.185708Z",
+ "iopub.status.busy": "2023-01-22T16:23:29.184841Z",
+ "iopub.status.idle": "2023-01-22T16:23:29.504659Z",
+ "shell.execute_reply": "2023-01-22T16:23:29.503366Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:29.185668Z"
+ },
+ "id": "692270ac",
+ "outputId": "7491ab1f-f9fb-4921-e383-7ecf5569e999"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " age_age_18_24 | \n",
+ " age_age_25_34 | \n",
+ " age_age_35_44 | \n",
+ " age_age_45_54 | \n",
+ " age_age_55_64 | \n",
+ " age_age_65_inf | \n",
+ " age_age_unknown | \n",
+ " income_income_0_20 | \n",
+ " income_income_150_inf | \n",
+ " income_income_20_40 | \n",
+ " income_income_40_60 | \n",
+ " income_income_60_90 | \n",
+ " income_income_90_150 | \n",
+ " income_income_unknown | \n",
+ " sex_F | \n",
+ " sex_M | \n",
+ " sex_sex_unknown | \n",
+ " kids_flg_False | \n",
+ " kids_flg_True | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 973171 | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 962099 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1047345 | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 721985 | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 704055 | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id age_age_18_24 age_age_25_34 age_age_35_44 age_age_45_54 \\\n",
+ "0 973171 False True False False \n",
+ "1 962099 True False False False \n",
+ "2 1047345 False False False True \n",
+ "3 721985 False False False True \n",
+ "4 704055 False False True False \n",
+ "\n",
+ " age_age_55_64 age_age_65_inf age_age_unknown income_income_0_20 \\\n",
+ "0 False False False False \n",
+ "1 False False False False \n",
+ "2 False False False False \n",
+ "3 False False False False \n",
+ "4 False False False False \n",
+ "\n",
+ " income_income_150_inf income_income_20_40 income_income_40_60 \\\n",
+ "0 False False False \n",
+ "1 False True False \n",
+ "2 False False True \n",
+ "3 False True False \n",
+ "4 False False False \n",
+ "\n",
+ " income_income_60_90 income_income_90_150 income_income_unknown sex_F \\\n",
+ "0 True False False False \n",
+ "1 False False False False \n",
+ "2 False False False True \n",
+ "3 False False False True \n",
+ "4 True False False True \n",
+ "\n",
+ " sex_M sex_sex_unknown kids_flg_False kids_flg_True \n",
+ "0 True False False True \n",
+ "1 True False True False \n",
+ "2 False False True False \n",
+ "3 False False True False \n",
+ "4 False False True False "
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "user_cat_feats = [\"age\", \"income\", \"sex\", \"kids_flg\"]\n",
+    "# из исходного датафрейма оставим только user_id - этот признак нам понадобится позже\n",
+    "# для того, чтобы сопоставлять пользователей из датафрейма users_df с пользователями\n",
+    "# из датафрейма с взаимодействиями\n",
+ "users_ohe_df = users_df.user_id\n",
+ "for feat in user_cat_feats:\n",
+ " # получаем датафрейм с one-hot encoding для каждой категориальной фичи\n",
+ " ohe_feat_df = pd.get_dummies(users_df[feat], prefix=feat)\n",
+    "    # конкатенируем one-hot датафрейм с датафреймом, \n",
+ " # который мы получили на предыдущем шаге\n",
+ " users_ohe_df = pd.concat([users_ohe_df, ohe_feat_df], axis=1)\n",
+ "\n",
+ "users_ohe_df.head()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "74cdbd93"
+ },
+ "source": [
+ "## Готовим фичи айтемов"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5kHzJ91Mr35c"
+ },
+ "source": [
+ "Кодируем их точно так же - one-hot'ом."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:29.507174Z",
+ "iopub.status.busy": "2023-01-22T16:23:29.506716Z",
+ "iopub.status.idle": "2023-01-22T16:23:29.528115Z",
+ "shell.execute_reply": "2023-01-22T16:23:29.526826Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:29.507133Z"
+ },
+ "id": "-2Wd9upSsCle",
+ "outputId": "671c2446-81f5-4e32-e24f-3aec9c8a2076"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " content_type | \n",
+ " title | \n",
+ " title_orig | \n",
+ " genres | \n",
+ " countries | \n",
+ " for_kids | \n",
+ " age_rating | \n",
+ " studios | \n",
+ " directors | \n",
+ " actors | \n",
+ " description | \n",
+ " keywords | \n",
+ " release_year_cat | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 10711 | \n",
+ " film | \n",
+ " поговори с ней | \n",
+ " Hable con ella | \n",
+ " драмы, зарубежные, детективы, мелодрамы | \n",
+ " испания | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " педро альмодовар | \n",
+ " Адольфо Фернандес, Ана Фернандес, Дарио Гранди... | \n",
+ " Мелодрама легендарного Педро Альмодовара «Пого... | \n",
+ " Поговори, ней, 2002, Испания, друзья, любовь, ... | \n",
+ " 2000-2010 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2508 | \n",
+ " film | \n",
+ " голые перцы | \n",
+ " Search Party | \n",
+ " зарубежные, приключения, комедии | \n",
+ " сша | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " скот армстронг | \n",
+ " Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... | \n",
+ " Уморительная современная комедия на популярную... | \n",
+ " Голые, перцы, 2014, США, друзья, свадьбы, прео... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 10716 | \n",
+ " film | \n",
+ " тактическая сила | \n",
+ " Tactical Force | \n",
+ " криминал, зарубежные, триллеры, боевики, комедии | \n",
+ " канада | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " адам п. калтраро | \n",
+ " Адриан Холмс, Даррен Шалави, Джерри Вассерман,... | \n",
+ " Профессиональный рестлер Стив Остин («Все или ... | \n",
+ " Тактическая, сила, 2011, Канада, бандиты, ганг... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 7868 | \n",
+ " film | \n",
+ " 45 лет | \n",
+ " 45 Years | \n",
+ " драмы, зарубежные, мелодрамы | \n",
+ " великобритания | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " эндрю хэй | \n",
+ " Александра Риддлстон-Барретт, Джеральдин Джейм... | \n",
+ " Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... | \n",
+ " 45, лет, 2015, Великобритания, брак, жизнь, лю... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 16268 | \n",
+ " film | \n",
+ " все решает мгновение | \n",
+ " NaN | \n",
+ " драмы, спорт, советские, мелодрамы | \n",
+ " ссср | \n",
+ " False | \n",
+ " 12.0 | \n",
+ " ленфильм | \n",
+ " виктор садовский | \n",
+ " Александр Абдулов, Александр Демьяненко, Алекс... | \n",
+ " Расчетливая чаровница из советского кинохита «... | \n",
+ " Все, решает, мгновение, 1978, СССР, сильные, ж... | \n",
+ " 1970-1980 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id content_type title title_orig \\\n",
+ "0 10711 film поговори с ней Hable con ella \n",
+ "1 2508 film голые перцы Search Party \n",
+ "2 10716 film тактическая сила Tactical Force \n",
+ "3 7868 film 45 лет 45 Years \n",
+ "4 16268 film все решает мгновение NaN \n",
+ "\n",
+ " genres countries for_kids \\\n",
+ "0 драмы, зарубежные, детективы, мелодрамы испания False \n",
+ "1 зарубежные, приключения, комедии сша False \n",
+ "2 криминал, зарубежные, триллеры, боевики, комедии канада False \n",
+ "3 драмы, зарубежные, мелодрамы великобритания False \n",
+ "4 драмы, спорт, советские, мелодрамы ссср False \n",
+ "\n",
+ " age_rating studios directors \\\n",
+ "0 16.0 unknown педро альмодовар \n",
+ "1 16.0 unknown скот армстронг \n",
+ "2 16.0 unknown адам п. калтраро \n",
+ "3 16.0 unknown эндрю хэй \n",
+ "4 12.0 ленфильм виктор садовский \n",
+ "\n",
+ " actors \\\n",
+ "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n",
+ "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n",
+ "2 Адриан Холмс, Даррен Шалави, Джерри Вассерман,... \n",
+ "3 Александра Риддлстон-Барретт, Джеральдин Джейм... \n",
+ "4 Александр Абдулов, Александр Демьяненко, Алекс... \n",
+ "\n",
+ " description \\\n",
+ "0 Мелодрама легендарного Педро Альмодовара «Пого... \n",
+ "1 Уморительная современная комедия на популярную... \n",
+ "2 Профессиональный рестлер Стив Остин («Все или ... \n",
+ "3 Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... \n",
+ "4 Расчетливая чаровница из советского кинохита «... \n",
+ "\n",
+ " keywords release_year_cat \n",
+ "0 Поговори, ней, 2002, Испания, друзья, любовь, ... 2000-2010 \n",
+ "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 2010-2020 \n",
+ "2 Тактическая, сила, 2011, Канада, бандиты, ганг... 2010-2020 \n",
+ "3 45, лет, 2015, Великобритания, брак, жизнь, лю... 2010-2020 \n",
+ "4 Все, решает, мгновение, 1978, СССР, сильные, ж... 1970-1980 "
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:37.792147Z",
+ "start_time": "2021-10-28T18:40:37.537501Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:29.534806Z",
+ "iopub.status.busy": "2023-01-22T16:23:29.533869Z",
+ "iopub.status.idle": "2023-01-22T16:23:30.291045Z",
+ "shell.execute_reply": "2023-01-22T16:23:30.289998Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:29.534762Z"
+ },
+ "id": "7a94ef7e",
+ "outputId": "1ea7a769-8c2d-43d5-f2cb-47500bc0a7ba"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " content_type_film | \n",
+ " content_type_series | \n",
+ " release_year_cat_1920-1930 | \n",
+ " release_year_cat_1930-1940 | \n",
+ " release_year_cat_1940-1950 | \n",
+ " release_year_cat_1950-1960 | \n",
+ " release_year_cat_1960-1970 | \n",
+ " release_year_cat_1970-1980 | \n",
+ " release_year_cat_1980-1990 | \n",
+ " ... | \n",
+ " directors_ярив хоровиц | \n",
+ " directors_ярон зильберман | \n",
+ " directors_ярополк лапшин | \n",
+ " directors_ярослав лупий | \n",
+ " directors_ярроу чейни, скотт моужер | \n",
+ " directors_ясина сезар | \n",
+ " directors_ясуоми умэцу | \n",
+ " directors_ёдзи фукуяма, ацуко фукусима, николас де креси, синъитиро ватанабэ, сёдзи кавамори | \n",
+ " directors_ёлкин туйчиев | \n",
+ " directors_ён сан-хо | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 10711 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2508 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 10716 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 7868 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 16268 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " ... | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5 rows × 8589 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id content_type_film content_type_series \\\n",
+ "0 10711 True False \n",
+ "1 2508 True False \n",
+ "2 10716 True False \n",
+ "3 7868 True False \n",
+ "4 16268 True False \n",
+ "\n",
+ " release_year_cat_1920-1930 release_year_cat_1930-1940 \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ " release_year_cat_1940-1950 release_year_cat_1950-1960 \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ " release_year_cat_1960-1970 release_year_cat_1970-1980 \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False True \n",
+ "\n",
+ " release_year_cat_1980-1990 ... directors_ярив хоровиц \\\n",
+ "0 False ... False \n",
+ "1 False ... False \n",
+ "2 False ... False \n",
+ "3 False ... False \n",
+ "4 False ... False \n",
+ "\n",
+ " directors_ярон зильберман directors_ярополк лапшин \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ " directors_ярослав лупий directors_ярроу чейни, скотт моужер \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ " directors_ясина сезар directors_ясуоми умэцу \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ " directors_ёдзи фукуяма, ацуко фукусима, николас де креси, синъитиро ватанабэ, сёдзи кавамори \\\n",
+ "0 False \n",
+ "1 False \n",
+ "2 False \n",
+ "3 False \n",
+ "4 False \n",
+ "\n",
+ " directors_ёлкин туйчиев directors_ён сан-хо \n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ "[5 rows x 8589 columns]"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "item_cat_feats = ['content_type', 'release_year_cat',\n",
+ " 'for_kids', 'age_rating', \n",
+ " 'studios', 'countries', 'directors']\n",
+ "\n",
+ "items_ohe_df = items_df.item_id\n",
+ "\n",
+ "for feat in item_cat_feats:\n",
+ " ohe_feat_df = pd.get_dummies(items_df[feat], prefix=feat)\n",
+ " items_ohe_df = pd.concat([items_ohe_df, ohe_feat_df], axis=1) \n",
+ "\n",
+ "items_ohe_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:30.294678Z",
+ "iopub.status.busy": "2023-01-22T16:23:30.294379Z",
+ "iopub.status.idle": "2023-01-22T16:23:30.316137Z",
+ "shell.execute_reply": "2023-01-22T16:23:30.314916Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:30.294651Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " content_type | \n",
+ " title | \n",
+ " title_orig | \n",
+ " genres | \n",
+ " countries | \n",
+ " for_kids | \n",
+ " age_rating | \n",
+ " studios | \n",
+ " directors | \n",
+ " actors | \n",
+ " description | \n",
+ " keywords | \n",
+ " release_year_cat | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 10711 | \n",
+ " film | \n",
+ " поговори с ней | \n",
+ " Hable con ella | \n",
+ " драмы, зарубежные, детективы, мелодрамы | \n",
+ " испания | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " педро альмодовар | \n",
+ " Адольфо Фернандес, Ана Фернандес, Дарио Гранди... | \n",
+ " Мелодрама легендарного Педро Альмодовара «Пого... | \n",
+ " Поговори, ней, 2002, Испания, друзья, любовь, ... | \n",
+ " 2000-2010 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2508 | \n",
+ " film | \n",
+ " голые перцы | \n",
+ " Search Party | \n",
+ " зарубежные, приключения, комедии | \n",
+ " сша | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " скот армстронг | \n",
+ " Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... | \n",
+ " Уморительная современная комедия на популярную... | \n",
+ " Голые, перцы, 2014, США, друзья, свадьбы, прео... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 10716 | \n",
+ " film | \n",
+ " тактическая сила | \n",
+ " Tactical Force | \n",
+ " криминал, зарубежные, триллеры, боевики, комедии | \n",
+ " канада | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " адам п. калтраро | \n",
+ " Адриан Холмс, Даррен Шалави, Джерри Вассерман,... | \n",
+ " Профессиональный рестлер Стив Остин («Все или ... | \n",
+ " Тактическая, сила, 2011, Канада, бандиты, ганг... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 7868 | \n",
+ " film | \n",
+ " 45 лет | \n",
+ " 45 Years | \n",
+ " драмы, зарубежные, мелодрамы | \n",
+ " великобритания | \n",
+ " False | \n",
+ " 16.0 | \n",
+ " unknown | \n",
+ " эндрю хэй | \n",
+ " Александра Риддлстон-Барретт, Джеральдин Джейм... | \n",
+ " Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... | \n",
+ " 45, лет, 2015, Великобритания, брак, жизнь, лю... | \n",
+ " 2010-2020 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 16268 | \n",
+ " film | \n",
+ " все решает мгновение | \n",
+ " NaN | \n",
+ " драмы, спорт, советские, мелодрамы | \n",
+ " ссср | \n",
+ " False | \n",
+ " 12.0 | \n",
+ " ленфильм | \n",
+ " виктор садовский | \n",
+ " Александр Абдулов, Александр Демьяненко, Алекс... | \n",
+ " Расчетливая чаровница из советского кинохита «... | \n",
+ " Все, решает, мгновение, 1978, СССР, сильные, ж... | \n",
+ " 1970-1980 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id content_type title title_orig \\\n",
+ "0 10711 film поговори с ней Hable con ella \n",
+ "1 2508 film голые перцы Search Party \n",
+ "2 10716 film тактическая сила Tactical Force \n",
+ "3 7868 film 45 лет 45 Years \n",
+ "4 16268 film все решает мгновение NaN \n",
+ "\n",
+ " genres countries for_kids \\\n",
+ "0 драмы, зарубежные, детективы, мелодрамы испания False \n",
+ "1 зарубежные, приключения, комедии сша False \n",
+ "2 криминал, зарубежные, триллеры, боевики, комедии канада False \n",
+ "3 драмы, зарубежные, мелодрамы великобритания False \n",
+ "4 драмы, спорт, советские, мелодрамы ссср False \n",
+ "\n",
+ " age_rating studios directors \\\n",
+ "0 16.0 unknown педро альмодовар \n",
+ "1 16.0 unknown скот армстронг \n",
+ "2 16.0 unknown адам п. калтраро \n",
+ "3 16.0 unknown эндрю хэй \n",
+ "4 12.0 ленфильм виктор садовский \n",
+ "\n",
+ " actors \\\n",
+ "0 Адольфо Фернандес, Ана Фернандес, Дарио Гранди... \n",
+ "1 Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ... \n",
+ "2 Адриан Холмс, Даррен Шалави, Джерри Вассерман,... \n",
+ "3 Александра Риддлстон-Барретт, Джеральдин Джейм... \n",
+ "4 Александр Абдулов, Александр Демьяненко, Алекс... \n",
+ "\n",
+ " description \\\n",
+ "0 Мелодрама легендарного Педро Альмодовара «Пого... \n",
+ "1 Уморительная современная комедия на популярную... \n",
+ "2 Профессиональный рестлер Стив Остин («Все или ... \n",
+ "3 Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей... \n",
+ "4 Расчетливая чаровница из советского кинохита «... \n",
+ "\n",
+ " keywords release_year_cat \n",
+ "0 Поговори, ней, 2002, Испания, друзья, любовь, ... 2000-2010 \n",
+ "1 Голые, перцы, 2014, США, друзья, свадьбы, прео... 2010-2020 \n",
+ "2 Тактическая, сила, 2011, Канада, бандиты, ганг... 2010-2020 \n",
+ "3 45, лет, 2015, Великобритания, брак, жизнь, лю... 2010-2020 \n",
+ "4 Все, решает, мгновение, 1978, СССР, сильные, ж... 1970-1980 "
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Добавим текстовые фичи\n",
+    "С помощью TfidfVectorizer получим TF-IDF-векторы текстовых колонок genres и keywords (колонка description в коде ниже не векторизуется)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:30.318830Z",
+ "iopub.status.busy": "2023-01-22T16:23:30.318190Z",
+ "iopub.status.idle": "2023-01-22T16:23:30.335666Z",
+ "shell.execute_reply": "2023-01-22T16:23:30.334811Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:30.318792Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.feature_extraction.text import TfidfVectorizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:30.338779Z",
+ "iopub.status.busy": "2023-01-22T16:23:30.338019Z",
+ "iopub.status.idle": "2023-01-22T16:23:31.386164Z",
+ "shell.execute_reply": "2023-01-22T16:23:31.385106Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:30.338741Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "for column in ['genres', 'keywords']:\n",
+ " tv = TfidfVectorizer(max_features = 500)\n",
+ " t = pd.DataFrame.sparse.from_spmatrix(tv.fit_transform(items_df[column]))\n",
+ " t.columns = [column + '_' + str(x) for x in t.columns]\n",
+ " items_ohe_df = pd.concat([items_ohe_df, t], axis = 1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:31.388255Z",
+ "iopub.status.busy": "2023-01-22T16:23:31.387847Z",
+ "iopub.status.idle": "2023-01-22T16:23:31.483917Z",
+ "shell.execute_reply": "2023-01-22T16:23:31.482751Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:31.388214Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " content_type_film | \n",
+ " content_type_series | \n",
+ " release_year_cat_1920-1930 | \n",
+ " release_year_cat_1930-1940 | \n",
+ " release_year_cat_1940-1950 | \n",
+ " release_year_cat_1950-1960 | \n",
+ " release_year_cat_1960-1970 | \n",
+ " release_year_cat_1970-1980 | \n",
+ " release_year_cat_1980-1990 | \n",
+ " ... | \n",
+ " keywords_490 | \n",
+ " keywords_491 | \n",
+ " keywords_492 | \n",
+ " keywords_493 | \n",
+ " keywords_494 | \n",
+ " keywords_495 | \n",
+ " keywords_496 | \n",
+ " keywords_497 | \n",
+ " keywords_498 | \n",
+ " keywords_499 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 10711 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2508 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 10716 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 7868 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 16268 | \n",
+ " True | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " False | \n",
+ " True | \n",
+ " False | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5 rows × 9197 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id content_type_film content_type_series \\\n",
+ "0 10711 True False \n",
+ "1 2508 True False \n",
+ "2 10716 True False \n",
+ "3 7868 True False \n",
+ "4 16268 True False \n",
+ "\n",
+ " release_year_cat_1920-1930 release_year_cat_1930-1940 \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ " release_year_cat_1940-1950 release_year_cat_1950-1960 \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False False \n",
+ "\n",
+ " release_year_cat_1960-1970 release_year_cat_1970-1980 \\\n",
+ "0 False False \n",
+ "1 False False \n",
+ "2 False False \n",
+ "3 False False \n",
+ "4 False True \n",
+ "\n",
+ " release_year_cat_1980-1990 ... keywords_490 keywords_491 keywords_492 \\\n",
+ "0 False ... 0.0 0.0 0.0 \n",
+ "1 False ... 0.0 0.0 0.0 \n",
+ "2 False ... 0.0 0.0 0.0 \n",
+ "3 False ... 0.0 0.0 0.0 \n",
+ "4 False ... 0.0 0.0 0.0 \n",
+ "\n",
+ " keywords_493 keywords_494 keywords_495 keywords_496 keywords_497 \\\n",
+ "0 0.0 0.0 0.0 0.0 0.0 \n",
+ "1 0.0 0.0 0.0 0.0 0.0 \n",
+ "2 0.0 0.0 0.0 0.0 0.0 \n",
+ "3 0.0 0.0 0.0 0.0 0.0 \n",
+ "4 0.0 0.0 0.0 0.0 0.0 \n",
+ "\n",
+ " keywords_498 keywords_499 \n",
+ "0 0.0 0.0 \n",
+ "1 0.0 0.0 \n",
+ "2 0.0 0.0 \n",
+ "3 0.0 0.0 \n",
+ "4 0.0 0.0 \n",
+ "\n",
+ "[5 rows x 9197 columns]"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_ohe_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cc595c20"
+ },
+ "source": [
+ "## Сделаем матрицу взаимодействий"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:37.898427Z",
+ "start_time": "2021-10-28T18:40:37.864067Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:31.486206Z",
+ "iopub.status.busy": "2023-01-22T16:23:31.485812Z",
+ "iopub.status.idle": "2023-01-22T16:23:31.604748Z",
+ "shell.execute_reply": "2023-01-22T16:23:31.603679Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:31.486170Z"
+ },
+ "id": "79c9bca3",
+ "outputId": "6f6148e0-8de7-4ffc-82d9-cf396db1ed98"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "item_id\n",
+ "10440 202457\n",
+ "15297 193123\n",
+ "9728 132865\n",
+ "13865 122119\n",
+ "4151 91167\n",
+ " ... \n",
+ "8076 1\n",
+ "8954 1\n",
+ "15664 1\n",
+ "818 1\n",
+ "10542 1\n",
+ "Name: count, Length: 15706, dtype: int64"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions_df.item_id.value_counts()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YAfqm8asrBfG"
+ },
+ "source": [
+ "В датасете взаимодействий есть непопулярные фильмы и малоактивные пользователи. Кроме того, в таблице взаимодействий есть события с низким качеством взаимодействия - когда юзер начал смотреть фильм, но вскоре после начала просмотра выключил.\n",
+ "\n",
+ "Отфильтруем такие события*, малоактивных юзеров и непопулярные фильмы.\n",
+ "\n",
+ "Можете не фильтровать такие события, тогда у вас будет больше негативных примеров."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:38.103819Z",
+ "start_time": "2021-10-28T18:40:38.070117Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:31.606489Z",
+ "iopub.status.busy": "2023-01-22T16:23:31.606197Z",
+ "iopub.status.idle": "2023-01-22T16:23:31.985392Z",
+ "shell.execute_reply": "2023-01-22T16:23:31.984254Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:31.606462Z"
+ },
+ "id": "17334e80",
+ "outputId": "bfbe26dd-7778-42ad-c5dd-283635fcafa6"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "user_id\n",
+ "416206 1341\n",
+ "1010539 764\n",
+ "555233 685\n",
+ "11526 676\n",
+ "409259 625\n",
+ " ... \n",
+ "45493 1\n",
+ "615194 1\n",
+ "96848 1\n",
+ "425823 1\n",
+ "697262 1\n",
+ "Name: count, Length: 962179, dtype: int64"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions_df.user_id.value_counts()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:39.717096Z",
+ "start_time": "2021-10-28T18:40:38.759740Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:31.987509Z",
+ "iopub.status.busy": "2023-01-22T16:23:31.986995Z",
+ "iopub.status.idle": "2023-01-22T16:23:34.897911Z",
+ "shell.execute_reply": "2023-01-22T16:23:34.896578Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:31.987469Z"
+ },
+ "id": "076e4ebc",
+ "outputId": "85c15fd2-12bb-478c-e00f-4f2b7bbcd6ab"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "N users before: 962179\n",
+ "N items before: 15706\n",
+ "\n",
+ "N users after: 79515\n",
+ "N items after: 6901\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"N users before: {interactions_df.user_id.nunique()}\")\n",
+ "print(f\"N items before: {interactions_df.item_id.nunique()}\\n\")\n",
+ "\n",
+ "# отфильтруем все события взаимодействий, в которых пользователь посмотрел\n",
+    "# фильм не более чем на 10 процентов\n",
+ "interactions_df = interactions_df[interactions_df.watched_pct > 10]\n",
+ "\n",
+ "# соберем всех пользователей, которые посмотрели \n",
+ "# больше 10 фильмов (можете выбрать другой порог)\n",
+ "valid_users = []\n",
+ "\n",
+ "c = Counter(interactions_df.user_id)\n",
+ "for user_id, entries in c.most_common():\n",
+ " if entries > 10:\n",
+ " valid_users.append(user_id)\n",
+ "\n",
+ "# и соберем все фильмы, которые посмотрели больше 10 пользователей\n",
+ "valid_items = []\n",
+ "\n",
+ "c = Counter(interactions_df.item_id)\n",
+ "for item_id, entries in c.most_common():\n",
+ " if entries > 10:\n",
+ " valid_items.append(item_id)\n",
+ "\n",
+ "# отбросим непопулярные фильмы и неактивных юзеров\n",
+ "interactions_df = interactions_df[interactions_df.user_id.isin(valid_users)]\n",
+ "interactions_df = interactions_df[interactions_df.item_id.isin(valid_items)]\n",
+ "\n",
+ "print(f\"N users after: {interactions_df.user_id.nunique()}\")\n",
+ "print(f\"N items after: {interactions_df.item_id.nunique()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a9163fb2"
+ },
+ "source": [
+    "После фильтрации может получиться так, что некоторые айтемы/юзеры есть в датасете взаимодействий, но при этом отсутствуют в датасетах айтемов/юзеров, или наоборот. Поэтому найдем id айтемов и id юзеров, которые есть во всех датасетах, и оставим только их."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:40.231703Z",
+ "start_time": "2021-10-28T18:40:39.718626Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:34.900180Z",
+ "iopub.status.busy": "2023-01-22T16:23:34.899760Z",
+ "iopub.status.idle": "2023-01-22T16:23:36.064882Z",
+ "shell.execute_reply": "2023-01-22T16:23:36.063765Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:34.900142Z"
+ },
+ "id": "d55848e1",
+ "outputId": "48609a0b-06b9-4a5e-f8f6-061db1c6dcb2"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "65974\n",
+ "6901\n"
+ ]
+ }
+ ],
+ "source": [
+ "common_users = set(interactions_df.user_id.unique()).intersection(set(users_ohe_df.user_id.unique()))\n",
+ "common_items = set(interactions_df.item_id.unique()).intersection(set(items_ohe_df.item_id.unique()))\n",
+ "\n",
+ "print(len(common_users))\n",
+ "print(len(common_items))\n",
+ "\n",
+ "interactions_df = interactions_df[interactions_df.item_id.isin(common_items)]\n",
+ "interactions_df = interactions_df[interactions_df.user_id.isin(common_users)]\n",
+ "\n",
+ "items_ohe_df = items_ohe_df[items_ohe_df.item_id.isin(common_items)]\n",
+ "users_ohe_df = users_ohe_df[users_ohe_df.user_id.isin(common_users)]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1e8b9480"
+ },
+ "source": [
+ "\n",
+    "Соберем взаимодействия в матрицу user × item так, чтобы в строках этой матрицы были user_id, в столбцах - item_id, а на пересечениях строк и столбцов - единица, если пользователь взаимодействовал с айтемом, и ноль, если нет.\n",
+ "\n",
+    "Такую матрицу удобно собирать в numpy array, однако нужно помнить, что numpy array индексируется порядковыми индексами, а нам удобнее использовать item_id и user_id.\n",
+ "\n",
+ "Создадим некие внутренние индексы для user_id и item_id - uid и iid. Для этого просто соберем все user_id и item_id и пронумеруем их по порядку."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:40:40.346587Z",
+ "start_time": "2021-10-28T18:40:40.233046Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:36.066990Z",
+ "iopub.status.busy": "2023-01-22T16:23:36.066574Z",
+ "iopub.status.idle": "2023-01-22T16:23:36.211726Z",
+ "shell.execute_reply": "2023-01-22T16:23:36.210597Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:36.066949Z"
+ },
+ "id": "81679fb0",
+ "outputId": "0c6bf7ce-1ea0-46c2-9d70-42b32bf08c7e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[0, 1, 2, 3, 4]\n",
+ "[0, 1, 2, 3, 4]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " last_watch_dt | \n",
+ " total_dur | \n",
+ " watched_pct | \n",
+ " uid | \n",
+ " iid | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 176549 | \n",
+ " 9506 | \n",
+ " 2021-05-11 | \n",
+ " 4250 | \n",
+ " 72 | \n",
+ " 10616 | \n",
+ " 3944 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 699317 | \n",
+ " 1659 | \n",
+ " 2021-05-29 | \n",
+ " 8317 | \n",
+ " 100 | \n",
+ " 42131 | \n",
+ " 675 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 1016458 | \n",
+ " 354 | \n",
+ " 2021-08-14 | \n",
+ " 1672 | \n",
+ " 25 | \n",
+ " 61024 | \n",
+ " 139 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 884009 | \n",
+ " 693 | \n",
+ " 2021-08-04 | \n",
+ " 703 | \n",
+ " 14 | \n",
+ " 53150 | \n",
+ " 279 | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " 5324 | \n",
+ " 8437 | \n",
+ " 2021-04-18 | \n",
+ " 6598 | \n",
+ " 92 | \n",
+ " 310 | \n",
+ " 3485 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id last_watch_dt total_dur watched_pct uid iid\n",
+ "0 176549 9506 2021-05-11 4250 72 10616 3944\n",
+ "1 699317 1659 2021-05-29 8317 100 42131 675\n",
+ "6 1016458 354 2021-08-14 1672 25 61024 139\n",
+ "7 884009 693 2021-08-04 703 14 53150 279\n",
+ "14 5324 8437 2021-04-18 6598 92 310 3485"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions_df[\"uid\"] = interactions_df[\"user_id\"].astype(\"category\")\n",
+ "interactions_df[\"uid\"] = interactions_df[\"uid\"].cat.codes\n",
+ "\n",
+ "interactions_df[\"iid\"] = interactions_df[\"item_id\"].astype(\"category\")\n",
+ "interactions_df[\"iid\"] = interactions_df[\"iid\"].cat.codes\n",
+ "\n",
+ "print(sorted(interactions_df.iid.unique())[:5])\n",
+ "print(sorted(interactions_df.uid.unique())[:5])\n",
+ "interactions_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "61c855e5"
+ },
+ "source": [
+ "Отнормируем матрицу взаимодействий"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:36.214161Z",
+ "iopub.status.busy": "2023-01-22T16:23:36.213276Z",
+ "iopub.status.idle": "2023-01-22T16:23:36.223246Z",
+ "shell.execute_reply": "2023-01-22T16:23:36.222069Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:36.214121Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0 3944\n",
+ "1 675\n",
+ "6 139\n",
+ "7 279\n",
+ "14 3485\n",
+ " ... \n",
+ "5476218 169\n",
+ "5476224 923\n",
+ "5476226 5610\n",
+ "5476239 2929\n",
+ "5476249 6766\n",
+ "Name: iid, Length: 1463641, dtype: int16"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "interactions_df.iid"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:03.360248Z",
+ "start_time": "2021-10-28T18:40:40.348057Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:36.225520Z",
+ "iopub.status.busy": "2023-01-22T16:23:36.224590Z",
+ "iopub.status.idle": "2023-01-22T16:23:43.629733Z",
+ "shell.execute_reply": "2023-01-22T16:23:43.628568Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:36.225480Z"
+ },
+ "id": "3feced70"
+ },
+ "outputs": [],
+ "source": [
+ "interactions_vec = np.zeros((interactions_df.uid.nunique(), \n",
+ " interactions_df.iid.nunique())) \n",
+ "\n",
+ "for user_id, item_id in zip(interactions_df.uid, interactions_df.iid):\n",
+ " interactions_vec[user_id, item_id] += 1\n",
+ "\n",
+ "\n",
+ "res = interactions_vec.sum(axis=1)\n",
+ "for i in range(len(interactions_vec)):\n",
+ " interactions_vec[i] /= res[i]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:03.416061Z",
+ "start_time": "2021-10-28T18:41:03.363462Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:43.634195Z",
+ "iopub.status.busy": "2023-01-22T16:23:43.631362Z",
+ "iopub.status.idle": "2023-01-22T16:23:43.711673Z",
+ "shell.execute_reply": "2023-01-22T16:23:43.710586Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:43.634161Z"
+ },
+ "id": "9f5ec90f",
+ "outputId": "9acdfe45-aa4e-4a64-ffdc-1a750390ae84"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6897\n",
+ "6901\n",
+ "65974\n",
+ "65974\n",
+ "{11805, 9788, 11501, 1734}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(interactions_df.item_id.nunique())\n",
+ "print(items_ohe_df.item_id.nunique())\n",
+ "print(interactions_df.user_id.nunique())\n",
+ "print(users_ohe_df.user_id.nunique())\n",
+ "\n",
+ "print(set(items_ohe_df.item_id.unique()) - set(interactions_df.item_id.unique()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:43.713551Z",
+ "iopub.status.busy": "2023-01-22T16:23:43.713196Z",
+ "iopub.status.idle": "2023-01-22T16:23:44.238808Z",
+ "shell.execute_reply": "2023-01-22T16:23:44.237691Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:43.713517Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "items_ohe_df = items_ohe_df[~items_ohe_df.item_id.isin([11805, 9788, 11501, 1734])]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "19e69bae"
+ },
+ "source": [
+    "Для того чтобы можно было удобно превратить iid/uid в item_id/user_id и наоборот, соберем словари\n",
+ "\n",
+ "{iid: item_id}, {uid: user_id} и {item_id: iid}, {user_id: uid}."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:03.637495Z",
+ "start_time": "2021-10-28T18:41:03.417544Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:44.243767Z",
+ "iopub.status.busy": "2023-01-22T16:23:44.243422Z",
+ "iopub.status.idle": "2023-01-22T16:23:44.817126Z",
+ "shell.execute_reply": "2023-01-22T16:23:44.816088Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:44.243739Z"
+ },
+ "id": "c8a84024"
+ },
+ "outputs": [],
+ "source": [
+ "iid_to_item_id = interactions_df[[\"iid\", \"item_id\"]].drop_duplicates().set_index(\"iid\").to_dict()[\"item_id\"]\n",
+ "item_id_to_iid = interactions_df[[\"iid\", \"item_id\"]].drop_duplicates().set_index(\"item_id\").to_dict()[\"iid\"]\n",
+ "\n",
+ "uid_to_user_id = interactions_df[[\"uid\", \"user_id\"]].drop_duplicates().set_index(\"uid\").to_dict()[\"user_id\"]\n",
+ "user_id_to_uid = interactions_df[[\"uid\", \"user_id\"]].drop_duplicates().set_index(\"user_id\").to_dict()[\"uid\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "48ca5204"
+ },
+ "source": [
+ "И проиндексируем датасеты users_ohe_df и items_ohe_df по внутренним айди:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:03.744883Z",
+ "start_time": "2021-10-28T18:41:03.638719Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:44.819593Z",
+ "iopub.status.busy": "2023-01-22T16:23:44.818859Z",
+ "iopub.status.idle": "2023-01-22T16:23:44.930257Z",
+ "shell.execute_reply": "2023-01-22T16:23:44.929032Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:44.819553Z"
+ },
+ "id": "4c4980ac"
+ },
+ "outputs": [],
+ "source": [
+ "items_ohe_df[\"iid\"] = items_ohe_df[\"item_id\"].apply(lambda x: item_id_to_iid[x])\n",
+ "items_ohe_df = items_ohe_df.set_index(\"iid\")\n",
+ "\n",
+ "users_ohe_df[\"uid\"] = users_ohe_df[\"user_id\"].apply(lambda x: user_id_to_uid[x])\n",
+ "users_ohe_df = users_ohe_df.set_index(\"uid\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:03.749717Z",
+ "start_time": "2021-10-28T18:41:03.746067Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:23:44.932306Z",
+ "iopub.status.busy": "2023-01-22T16:23:44.931684Z",
+ "iopub.status.idle": "2023-01-22T16:23:44.939719Z",
+ "shell.execute_reply": "2023-01-22T16:23:44.938755Z",
+ "shell.execute_reply.started": "2023-01-22T16:23:44.932267Z"
+ },
+ "id": "22c26d39"
+ },
+ "outputs": [],
+ "source": [
+ "def triplet_loss(y_true, y_pred, n_dims=128, alpha=0.4):\n",
+ " # будем ожидать, что на вход функции прилетит три сконкатенированных \n",
+ " # вектора - вектор юзера и два вектора айтема\n",
+ " anchor = y_pred[:, 0:n_dims]\n",
+ " positive = y_pred[:, n_dims:n_dims*2]\n",
+ " negative = y_pred[:, n_dims*2:n_dims*3]\n",
+ "\n",
+ " # считаем расстояния от вектора юзера до вектора хорошего айтема\n",
+ " pos_dist = K.sum(K.square(anchor - positive), axis=1)\n",
+ " # и до плохого\n",
+ " neg_dist = K.sum(K.square(anchor - negative), axis=1)\n",
+ "\n",
+ " # считаем лосс\n",
+ " basic_loss = pos_dist - neg_dist + alpha\n",
+ " loss = K.maximum(basic_loss, 0.0) # возвращаем ноль, если лосс отрицательный\n",
+ " \n",
+ " return loss\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T19:19:05.615364Z",
+ "start_time": "2021-10-28T19:19:05.612463Z"
+ },
+ "id": "4de262b4"
+ },
+ "source": [
+ "Попробуйте другие лоссы, например, BPR Triplet loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:18:11.520568Z",
+ "iopub.status.busy": "2023-01-22T16:18:11.519791Z",
+ "iopub.status.idle": "2023-01-22T16:18:11.535194Z",
+ "shell.execute_reply": "2023-01-22T16:18:11.533962Z",
+ "shell.execute_reply.started": "2023-01-22T16:18:11.520528Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def bpr_triplet_loss(y_true, y_pred, n_dims=128):\n",
+ " \n",
+ " from keras import backend as K\n",
+ " \n",
+ " anchor = y_pred[:, 0:n_dims]\n",
+ " positive = y_pred[:, n_dims:n_dims*2]\n",
+ " negative = y_pred[:, n_dims*2:n_dims*3]\n",
+ "\n",
+ " # BPR loss\n",
+ " loss = 1.0 - K.sigmoid(\n",
+ " K.sum(anchor * positive, axis=-1, keepdims=True) -\n",
+ " K.sum(anchor * negative, axis=-1, keepdims=True))\n",
+ "\n",
+ " return loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-23T11:20:03.327838Z",
+ "start_time": "2021-10-23T11:20:03.324389Z"
+ },
+ "id": "85d618b6"
+ },
+ "source": [
+ "## Генератор и семплирование\n",
+ "\n",
+ "- хорошим примером будет тот айтем, который был взят из датасета взаимодействий в соответствии с распределением просмотренных айтемов для этого юзера;\n",
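+    "- в качестве негативного примера буду случайно брать айтем из 100 наиболее непохожих на положительный айтем по евклидову расстоянию между векторами жанров и ключевых слов; по задумке это айтем, который человек при этом не смотрел (быстрый вариант с заранее посчитанным словарем просмотренные юзером айтемы не исключает).\n",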
+ "\n",
+    "Таким образом, если, например, человек целиком посмотрел триллер, то в негативные примеры для него должно попасть что-то вроде мелодрамы, причем ключевые слова тоже будут сильно отличаться.\n",
+ "\n",
+ "\n",
+    "Заранее сформируем словарь: для каждого айтема — список из ста наиболее непохожих на него айтемов. Тогда в генераторе достаточно будет взять случайный айтем из ста кандидатов для положительного айтема. Если считать расстояния прямо во время работы генератора, получается чрезвычайно долго, а здесь обращение к словарю — O(1), и выбор случайного значения по сложности такой же, как в простом генераторе.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:37:25.302855Z",
+ "iopub.status.busy": "2023-01-22T15:37:25.302472Z",
+ "iopub.status.idle": "2023-01-22T15:37:32.215552Z",
+ "shell.execute_reply": "2023-01-22T15:37:32.214488Z",
+ "shell.execute_reply.started": "2023-01-22T15:37:25.302820Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 6897/6897 [00:02<00:00, 2513.96it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+    "# формируем словарь\n",
+ "\n",
+ "fts = items_ohe_df[[x for x in items_ohe_df if 'genre' in x or 'keywords' in x]]\n",
+ "\n",
+ "distances = pd.DataFrame(ED(fts))\n",
+ "distances.columns = list(fts.index)\n",
+ "distances.index = fts.index\n",
+ "\n",
+ "distance_dict = {}\n",
+ "for i in tqdm(distances.columns):\n",
+ " distance_dict[i] = list(distances[i].sort_values()[-100:].index)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:39:46.221189Z",
+ "iopub.status.busy": "2023-01-22T12:39:46.220459Z",
+ "iopub.status.idle": "2023-01-22T12:39:49.254755Z",
+ "shell.execute_reply": "2023-01-22T12:39:49.253714Z",
+ "shell.execute_reply.started": "2023-01-22T12:39:46.221147Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "iids_ = np.array(fts.index)\n",
+ "user_interactions = interactions_df.groupby(\"uid\")['iid'].apply(lambda x: np.array(x.unique())).to_dict()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:37:35.885246Z",
+ "iopub.status.busy": "2023-01-22T15:37:35.884025Z",
+ "iopub.status.idle": "2023-01-22T15:37:35.891070Z",
+ "shell.execute_reply": "2023-01-22T15:37:35.889851Z",
+ "shell.execute_reply.started": "2023-01-22T15:37:35.885203Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def get_negative_sample(pos_i, uid_i, distance_dict):\n",
+ " \n",
+ " neg_i = np.random.choice(distance_dict[pos_i])\n",
+ " \n",
+ " return neg_i"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T12:39:49.263942Z",
+ "iopub.status.busy": "2023-01-22T12:39:49.263310Z",
+ "iopub.status.idle": "2023-01-22T12:39:49.273779Z",
+ "shell.execute_reply": "2023-01-22T12:39:49.272870Z",
+ "shell.execute_reply.started": "2023-01-22T12:39:49.263906Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# функция для нахождения отрицательных item\n",
+ "\n",
+ "# очень долго работает \n",
+ "def get_negative_sample_old(pos_i, uid_i, fts, iids_, user_interactions):\n",
+ " \n",
+    "    # айтемы, с которыми взаимодействовал юзер, их исключим\n",
+ " user_watched_items = user_interactions[uid_i]\n",
+ " \n",
+    "    # векторы айтемов, которые не смотрел юзер, и по которым посчитаем евклидовы дистанции,\n",
+ " # чтобы найти самые непохожие на тот айтем, который юзер смотрел\n",
+ "\n",
+ " # из всего списка item вычитаем те, с которыми пользователь взаимодействовал\n",
+ " # список item которых пользователь не видел\n",
+ " inters = np.setdiff1d(iids_, user_watched_items, assume_unique=True)\n",
+ " \n",
+ " fts_ = fts.loc[inters].sample(n = 100)\n",
+ " \n",
+ " # вектор позитивного айтема \n",
+ " pos_item_fts = pd.DataFrame(fts.loc[pos_i, :]).T\n",
+ " \n",
+ " # считаем дистанции\n",
+ " dists = ED(fts_, pos_item_fts)\n",
+ " \n",
+ " # берем десять самых непохожих и непросмотренных юзером айтемов и из них случайно выбираем один \n",
+ " fts_['dists'] = dists\n",
+ " fts_ = fts_[['dists']]\n",
+ " neg_candidates = fts_.sort_values(by = \"dists\")[-10:].index\n",
+ " \n",
+ " neg_i = np.random.choice(neg_candidates)\n",
+ " \n",
+ " return neg_i"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:03.755386Z",
+ "start_time": "2021-10-28T18:41:03.750664Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:37:53.122995Z",
+ "iopub.status.busy": "2023-01-22T15:37:53.122612Z",
+ "iopub.status.idle": "2023-01-22T15:37:53.132222Z",
+ "shell.execute_reply": "2023-01-22T15:37:53.130866Z",
+ "shell.execute_reply.started": "2023-01-22T15:37:53.122960Z"
+ },
+ "id": "7829878b"
+ },
+ "outputs": [],
+ "source": [
+ "def generator(items, users, interactions, batch_size=1024):\n",
+ " while True:\n",
+ " uid_meta = []\n",
+ " uid_interaction = []\n",
+ " pos = []\n",
+ " neg = []\n",
+ " for _ in range(batch_size):\n",
+ " # берем рандомный uid\n",
+ " uid_i = randint(0, interactions.shape[0]-1)\n",
+ " # id хорошего айтема\n",
+ " pos_i = np.random.choice(range(interactions.shape[1]), p=interactions[uid_i])\n",
+ " # id плохого айтема\n",
+ " #neg_i = np.random.choice(range(interactions.shape[1]))\n",
+ " #neg_i = get_negative_sample_old(pos_i, uid_i, fts, iids_, user_interactions)\n",
+ " neg_i = get_negative_sample(pos_i, uid_i, distance_dict)\n",
+ " # фичи юзера\n",
+ " uid_meta.append(users.iloc[uid_i])\n",
+ " # вектор айтемов, с которыми юзер взаимодействовал\n",
+ " uid_interaction.append(interactions_vec[uid_i])\n",
+ " # фичи хорошего айтема\n",
+ " pos.append(items.iloc[pos_i])\n",
+ " # фичи плохого айтема\n",
+ " neg.append(items.iloc[neg_i])\n",
+ " \n",
+ " yield [np.array(uid_meta), np.array(uid_interaction), np.array(pos), np.array(neg)], [np.array(uid_meta), np.array(uid_interaction)]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.386864Z",
+ "start_time": "2021-10-28T18:41:03.756363Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:37:57.807501Z",
+ "iopub.status.busy": "2023-01-22T15:37:57.807136Z",
+ "iopub.status.idle": "2023-01-22T15:38:48.900316Z",
+ "shell.execute_reply": "2023-01-22T15:38:48.899211Z",
+ "shell.execute_reply.started": "2023-01-22T15:37:57.807471Z"
+ },
+ "id": "af9d3c3b",
+ "outputId": "1040f567-f64a-4ccb-91a8-48034694dfdc"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "вектор фичей юзера: (1024, 19)\n",
+ "вектор взаимодействий юзера с айтемами: (1024, 6897)\n",
+ "вектор 'хорошего' айтема: (1024, 9196)\n",
+ "вектор 'плохого' айтема: (1024, 9196)\n",
+ "\n",
+ "вектор фичей юзера: (1024, 19)\n",
+ "вектор взаимодействий юзера с айтемами: (1024, 6897)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# инициализируем генератор\n",
+ "gen = generator(items=items_ohe_df.drop([\"item_id\"], axis=1), \n",
+ " users=users_ohe_df.drop([\"user_id\"], axis=1), \n",
+ " interactions=interactions_vec, batch_size=1024)\n",
+ "\n",
+ "ret = next(gen)\n",
+ "\n",
+ "\n",
+ "print(f\"вектор фичей юзера: {ret[0][0].shape}\")\n",
+ "print(f\"вектор взаимодействий юзера с айтемами: {ret[0][1].shape}\")\n",
+ "print(f\"вектор 'хорошего' айтема: {ret[0][2].shape}\")\n",
+ "print(f\"вектор 'плохого' айтема: {ret[0][3].shape}\")\n",
+ "print()\n",
+ "print(f\"вектор фичей юзера: {ret[1][0].shape}\")\n",
+ "print(f\"вектор взаимодействий юзера с айтемами: {ret[1][1].shape}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8bcc3e80"
+ },
+ "source": [
+ "## Генератор, который будет использовать информацию о качестве взаимодействия юзеров с айтемами для более репрезентативного сэмплирования\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.493030Z",
+ "start_time": "2021-10-28T18:41:16.388592Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:48.903047Z",
+ "iopub.status.busy": "2023-01-22T15:38:48.902586Z",
+ "iopub.status.idle": "2023-01-22T15:38:49.025937Z",
+ "shell.execute_reply": "2023-01-22T15:38:49.024831Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:48.902992Z"
+ },
+ "id": "967b819f",
+ "outputId": "2f7a5885-dcb3-4ab8-80f8-57a21635595d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "N_FACTORS: 128\n",
+ "ITEM_MODEL_SHAPE: (9196,)\n",
+ "USER_META_MODEL_SHAPE: (19,)\n",
+ "USER_INTERACTION_MODEL_SHAPE: (6897,)\n"
+ ]
+ }
+ ],
+ "source": [
+ "N_FACTORS = 128\n",
+ "\n",
+ "# в датасетах есть столбец user_id/item_id, помним, что он не является фичей для обучения!\n",
+ "ITEM_MODEL_SHAPE = (items_ohe_df.drop([\"item_id\"], axis=1).shape[1], ) \n",
+ "USER_META_MODEL_SHAPE = (users_ohe_df.drop([\"user_id\"], axis=1).shape[1], )\n",
+ "\n",
+ "USER_INTERACTION_MODEL_SHAPE = (interactions_vec.shape[1], )\n",
+ "\n",
+ "print(f\"N_FACTORS: {N_FACTORS}\")\n",
+ "print(f\"ITEM_MODEL_SHAPE: {ITEM_MODEL_SHAPE}\")\n",
+ "print(f\"USER_META_MODEL_SHAPE: {USER_META_MODEL_SHAPE}\")\n",
+ "print(f\"USER_INTERACTION_MODEL_SHAPE: {USER_INTERACTION_MODEL_SHAPE}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.816499Z",
+ "start_time": "2021-10-28T18:41:16.494387Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:49.027755Z",
+ "iopub.status.busy": "2023-01-22T15:38:49.027467Z",
+ "iopub.status.idle": "2023-01-22T15:38:53.151538Z",
+ "shell.execute_reply": "2023-01-22T15:38:53.150595Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:49.027729Z"
+ },
+ "id": "de649a01"
+ },
+ "outputs": [],
+ "source": [
+ "def item_model(n_factors=N_FACTORS):\n",
+ " # входной слой\n",
+ " inp = keras.layers.Input(shape=ITEM_MODEL_SHAPE)\n",
+ " \n",
+ " # полносвязный слой\n",
+ " layer_1 = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n",
+ " kernel_regularizer=keras.regularizers.l2(1e-6),\n",
+ " activity_regularizer=keras.regularizers.l2(l2=1e-6))(inp)\n",
+ "\n",
+ " # делаем residual connection - складываем два слоя, \n",
+ " # чтобы градиенты не затухали во время обучения\n",
+ " layer_2 = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n",
+ " kernel_regularizer=keras.regularizers.l2(1e-6),\n",
+ " activity_regularizer=keras.regularizers.l2(l2=1e-6))(layer_1)\n",
+ " \n",
+ " add = keras.layers.Add()([layer_1, layer_2])\n",
+ " \n",
+ " # выходной слой\n",
+ " out = keras.layers.Dense(N_FACTORS, activation='linear', use_bias=False,\n",
+ " kernel_regularizer=keras.regularizers.l2(1e-6),\n",
+ " activity_regularizer=keras.regularizers.l2(l2=1e-6))(add)\n",
+ " \n",
+ " return keras.models.Model(inp, out)\n",
+ "\n",
+ "\n",
+ "def user_model(n_factors=N_FACTORS):\n",
+ " # входной слой для вектора фичей юзера (из users_ohe_df)\n",
+ " inp_meta = keras.layers.Input(shape=USER_META_MODEL_SHAPE)\n",
+ " # входной слой для вектора просмотров (из interactions_vec)\n",
+ " inp_interaction = keras.layers.Input(shape=USER_INTERACTION_MODEL_SHAPE)\n",
+ "\n",
+ " # полносвязный слой\n",
+ " layer_1_meta = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n",
+ " kernel_regularizer=keras.regularizers.l2(1e-6),\n",
+ " activity_regularizer=keras.regularizers.l2(l2=1e-6))(inp_meta)\n",
+ "\n",
+ " layer_1_interaction = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n",
+ " kernel_regularizer=keras.regularizers.l2(1e-6),\n",
+ " activity_regularizer=keras.regularizers.l2(l2=1e-6))(inp_interaction)\n",
+ "\n",
+ " # делаем residual connection - складываем два слоя,\n",
+ " # чтобы градиенты не затухали во время обучения\n",
+ " layer_2_meta = keras.layers.Dense(N_FACTORS, activation='elu', use_bias=False,\n",
+ " kernel_regularizer=keras.regularizers.l2(1e-6),\n",
+ " activity_regularizer=keras.regularizers.l2(l2=1e-6))(layer_1_meta)\n",
+ " \n",
+ "\n",
+ " add = keras.layers.Add()([layer_1_meta, layer_2_meta])\n",
+ " \n",
+ " # конкатенируем вектор фичей с вектором просмотров\n",
+ " concat_meta_interaction = keras.layers.Concatenate()([add, layer_1_interaction])\n",
+ " \n",
+ " # выходной слой\n",
+ " out = keras.layers.Dense(N_FACTORS, activation='linear', use_bias=False,\n",
+ " kernel_regularizer=keras.regularizers.l2(1e-6),\n",
+ " activity_regularizer=keras.regularizers.l2(l2=1e-6))(concat_meta_interaction)\n",
+ " \n",
+ " return keras.models.Model([inp_meta, inp_interaction], out)\n",
+ "\n",
+ "# инициализируем модели юзера и айтема\n",
+ "i2v = item_model()\n",
+ "u2v = user_model()\n",
+ "\n",
+ "# вход для вектора фичей юзера (из users_ohe_df)\n",
+ "ancor_meta_in = keras.layers.Input(shape=USER_META_MODEL_SHAPE)\n",
+ "# вход для вектора просмотра юзера (из interactions_vec)\n",
+ "ancor_interaction_in = keras.layers.Input(shape=USER_INTERACTION_MODEL_SHAPE)\n",
+ "\n",
+ "# вход для вектора \"хорошего\" айтема\n",
+ "pos_in = keras.layers.Input(shape=ITEM_MODEL_SHAPE)\n",
+ "# вход для вектора \"плохого\" айтема\n",
+ "neg_in = keras.layers.Input(shape=ITEM_MODEL_SHAPE)\n",
+ "\n",
+ "# получаем вектор юзера\n",
+ "ancor = u2v([ancor_meta_in, ancor_interaction_in])\n",
+ "# получаем вектор \"хорошего\" айтема\n",
+ "pos = i2v(pos_in)\n",
+ "# получаем вектор \"плохого\" айтема\n",
+ "neg = i2v(neg_in)\n",
+ "\n",
+ "# конкатенируем полученные векторы\n",
+ "res = keras.layers.Concatenate(name=\"concat_ancor_pos_neg\")([ancor, pos, neg])\n",
+ "\n",
+ "# собираем модель\n",
+ "model = keras.models.Model([ancor_meta_in, ancor_interaction_in, pos_in, neg_in], res)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.822662Z",
+ "start_time": "2021-10-28T18:41:16.817857Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:53.154784Z",
+ "iopub.status.busy": "2023-01-22T15:38:53.154419Z",
+ "iopub.status.idle": "2023-01-22T15:38:53.789679Z",
+ "shell.execute_reply": "2023-01-22T15:38:53.788675Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:53.154748Z"
+ },
+ "id": "e912d920"
+ },
+ "outputs": [],
+ "source": [
+ "model_name = 'recsys_resnet_linear'\n",
+ "\n",
+ "# логируем процесс обучения в тензорборд\n",
+ "t_board = keras.callbacks.TensorBoard(log_dir=f'runs/{model_name}')\n",
+ "\n",
+ "# уменьшаем learning_rate, если лосс долго не уменьшается (в течение двух эпох)\n",
+ "decay = keras.callbacks.ReduceLROnPlateau(monitor='loss', patience=2, factor=0.8, verbose=1)\n",
+ "\n",
+ "# сохраняем модель после каждой эпохи; save_best_only не задан, поэтому чекпоинт пишется всегда, а не только при уменьшении лосса\n",
+ "check = keras.callbacks.ModelCheckpoint(filepath=model_name + '/epoch{epoch}-{loss:.2f}.h5', monitor=\"loss\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.832365Z",
+ "start_time": "2021-10-28T18:41:16.824484Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:53.792105Z",
+ "iopub.status.busy": "2023-01-22T15:38:53.791371Z",
+ "iopub.status.idle": "2023-01-22T15:38:53.808624Z",
+ "shell.execute_reply": "2023-01-22T15:38:53.807732Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:53.792041Z"
+ },
+ "id": "f95049f6"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.\n",
+ "WARNING:absl:`lr` is deprecated in Keras optimizer, please use `learning_rate` or use the legacy optimizer, e.g.,tf.keras.optimizers.legacy.Adam.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# компилируем модель, используем оптимайзер Adam и triplet loss\n",
+ "opt = keras.optimizers.Adam(lr=0.001)\n",
+ "model.compile(loss=triplet_loss, optimizer=opt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.867472Z",
+ "start_time": "2021-10-28T18:41:16.833753Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:53.811821Z",
+ "iopub.status.busy": "2023-01-22T15:38:53.811155Z",
+ "iopub.status.idle": "2023-01-22T15:38:53.852098Z",
+ "shell.execute_reply": "2023-01-22T15:38:53.851090Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:53.811786Z"
+ },
+ "id": "fb9382d0",
+ "outputId": "2eca9a17-1544-4e27-a483-b86d11391767"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: \"model_3\"\n",
+ "__________________________________________________________________________________________________\n",
+ " Layer (type) Output Shape Param # Connected to \n",
+ "==================================================================================================\n",
+ " input_8 (InputLayer) [(None, 9196)] 0 [] \n",
+ " \n",
+ " dense_7 (Dense) (None, 128) 1177088 ['input_8[0][0]'] \n",
+ " \n",
+ " dense_8 (Dense) (None, 128) 16384 ['dense_7[0][0]'] \n",
+ " \n",
+ " add_2 (Add) (None, 128) 0 ['dense_7[0][0]', \n",
+ " 'dense_8[0][0]'] \n",
+ " \n",
+ " dense_9 (Dense) (None, 128) 16384 ['add_2[0][0]'] \n",
+ " \n",
+ "==================================================================================================\n",
+ "Total params: 1209856 (4.62 MB)\n",
+ "Trainable params: 1209856 (4.62 MB)\n",
+ "Non-trainable params: 0 (0.00 Byte)\n",
+ "__________________________________________________________________________________________________\n"
+ ]
+ }
+ ],
+ "source": [
+ "# модель айтема\n",
+ "item_model().summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.923402Z",
+ "start_time": "2021-10-28T18:41:16.868877Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:53.854198Z",
+ "iopub.status.busy": "2023-01-22T15:38:53.853594Z",
+ "iopub.status.idle": "2023-01-22T15:38:53.908177Z",
+ "shell.execute_reply": "2023-01-22T15:38:53.907222Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:53.854161Z"
+ },
+ "id": "286149d1",
+ "outputId": "4284ba09-05ef-4963-c637-67e919701d19"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: \"model_4\"\n",
+ "__________________________________________________________________________________________________\n",
+ " Layer (type) Output Shape Param # Connected to \n",
+ "==================================================================================================\n",
+ " input_9 (InputLayer) [(None, 19)] 0 [] \n",
+ " \n",
+ " dense_10 (Dense) (None, 128) 2432 ['input_9[0][0]'] \n",
+ " \n",
+ " dense_12 (Dense) (None, 128) 16384 ['dense_10[0][0]'] \n",
+ " \n",
+ " input_10 (InputLayer) [(None, 6897)] 0 [] \n",
+ " \n",
+ " add_3 (Add) (None, 128) 0 ['dense_10[0][0]', \n",
+ " 'dense_12[0][0]'] \n",
+ " \n",
+ " dense_11 (Dense) (None, 128) 882816 ['input_10[0][0]'] \n",
+ " \n",
+ " concatenate_1 (Concatenate (None, 256) 0 ['add_3[0][0]', \n",
+ " ) 'dense_11[0][0]'] \n",
+ " \n",
+ " dense_13 (Dense) (None, 128) 32768 ['concatenate_1[0][0]'] \n",
+ " \n",
+ "==================================================================================================\n",
+ "Total params: 934400 (3.56 MB)\n",
+ "Trainable params: 934400 (3.56 MB)\n",
+ "Non-trainable params: 0 (0.00 Byte)\n",
+ "__________________________________________________________________________________________________\n"
+ ]
+ }
+ ],
+ "source": [
+ "# модель юзера\n",
+ "user_model().summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T18:41:16.929341Z",
+ "start_time": "2021-10-28T18:41:16.924663Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:53.909970Z",
+ "iopub.status.busy": "2023-01-22T15:38:53.909370Z",
+ "iopub.status.idle": "2023-01-22T15:38:53.917202Z",
+ "shell.execute_reply": "2023-01-22T15:38:53.916103Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:53.909934Z"
+ },
+ "id": "d9f25a3f",
+ "outputId": "6f9a3700-4420-4345-8331-82f7207b566b"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: \"model_2\"\n",
+ "__________________________________________________________________________________________________\n",
+ " Layer (type) Output Shape Param # Connected to \n",
+ "==================================================================================================\n",
+ " input_4 (InputLayer) [(None, 19)] 0 [] \n",
+ " \n",
+ " input_5 (InputLayer) [(None, 6897)] 0 [] \n",
+ " \n",
+ " input_6 (InputLayer) [(None, 9196)] 0 [] \n",
+ " \n",
+ " input_7 (InputLayer) [(None, 9196)] 0 [] \n",
+ " \n",
+ " model_1 (Functional) (None, 128) 934400 ['input_4[0][0]', \n",
+ " 'input_5[0][0]'] \n",
+ " \n",
+ " model (Functional) (None, 128) 1209856 ['input_6[0][0]', \n",
+ " 'input_7[0][0]'] \n",
+ " \n",
+ " concat_ancor_pos_neg (Conc (None, 384) 0 ['model_1[0][0]', \n",
+ " atenate) 'model[0][0]', \n",
+ " 'model[1][0]'] \n",
+ " \n",
+ "==================================================================================================\n",
+ "Total params: 2144256 (8.18 MB)\n",
+ "Trainable params: 2144256 (8.18 MB)\n",
+ "Non-trainable params: 0 (0.00 Byte)\n",
+ "__________________________________________________________________________________________________\n"
+ ]
+ }
+ ],
+ "source": [
+ "# общая модель\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T19:15:21.657529Z",
+ "start_time": "2021-10-28T19:15:16.365923Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T15:38:53.919463Z",
+ "iopub.status.busy": "2023-01-22T15:38:53.918611Z",
+ "iopub.status.idle": "2023-01-22T16:17:01.448835Z",
+ "shell.execute_reply": "2023-01-22T16:17:01.447888Z",
+ "shell.execute_reply.started": "2023-01-22T15:38:53.919424Z"
+ },
+ "id": "99d50830",
+ "outputId": "cee25813-2173-460f-e6f2-024d75d1db08"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/30\n",
+ "100/100 [==============================] - 42s 418ms/step - loss: 0.4123 - lr: 0.0010\n",
+ "Epoch 2/30\n",
+ "100/100 [==============================] - 42s 425ms/step - loss: 0.2973 - lr: 0.0010\n",
+ "Epoch 3/30\n",
+ "100/100 [==============================] - 42s 423ms/step - loss: 0.2713 - lr: 0.0010\n",
+ "Epoch 4/30\n",
+ "100/100 [==============================] - 642s 6s/step - loss: 0.2240 - lr: 0.0010\n",
+ "Epoch 5/30\n",
+ "100/100 [==============================] - 41s 413ms/step - loss: 0.2200 - lr: 0.0010\n",
+ "Epoch 6/30\n",
+ "100/100 [==============================] - 41s 415ms/step - loss: 0.1929 - lr: 0.0010\n",
+ "Epoch 7/30\n",
+ "100/100 [==============================] - 42s 427ms/step - loss: 0.1727 - lr: 0.0010\n",
+ "Epoch 8/30\n",
+ "100/100 [==============================] - 42s 423ms/step - loss: 0.1849 - lr: 0.0010\n",
+ "Epoch 9/30\n",
+ "100/100 [==============================] - 41s 418ms/step - loss: 0.1594 - lr: 0.0010\n",
+ "Epoch 10/30\n",
+ "100/100 [==============================] - 42s 420ms/step - loss: 0.1485 - lr: 0.0010\n",
+ "Epoch 11/30\n",
+ "100/100 [==============================] - 41s 417ms/step - loss: 0.1523 - lr: 0.0010\n",
+ "Epoch 12/30\n",
+ "100/100 [==============================] - 41s 415ms/step - loss: 0.1328 - lr: 0.0010\n",
+ "Epoch 13/30\n",
+ "100/100 [==============================] - 41s 419ms/step - loss: 0.1407 - lr: 0.0010\n",
+ "Epoch 14/30\n",
+ "100/100 [==============================] - ETA: 0s - loss: 0.1524\n",
+ "Epoch 14: ReduceLROnPlateau reducing learning rate to 0.000800000037997961.\n",
+ "100/100 [==============================] - 42s 421ms/step - loss: 0.1524 - lr: 0.0010\n",
+ "Epoch 15/30\n",
+ "100/100 [==============================] - 42s 420ms/step - loss: 0.1304 - lr: 8.0000e-04\n",
+ "Epoch 16/30\n",
+ "100/100 [==============================] - 42s 421ms/step - loss: 0.1305 - lr: 8.0000e-04\n",
+ "Epoch 17/30\n",
+ "100/100 [==============================] - 41s 415ms/step - loss: 0.1299 - lr: 8.0000e-04\n",
+ "Epoch 18/30\n",
+ "100/100 [==============================] - 41s 416ms/step - loss: 0.1332 - lr: 8.0000e-04\n",
+ "Epoch 19/30\n",
+ "100/100 [==============================] - 578s 6s/step - loss: 0.1128 - lr: 8.0000e-04\n",
+ "Epoch 20/30\n",
+ "100/100 [==============================] - 298s 3s/step - loss: 0.1168 - lr: 8.0000e-04\n",
+ "Epoch 21/30\n",
+ "100/100 [==============================] - ETA: 0s - loss: 0.1145\n",
+ "Epoch 21: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.\n",
+ "100/100 [==============================] - 42s 424ms/step - loss: 0.1145 - lr: 8.0000e-04\n",
+ "Epoch 22/30\n",
+ "100/100 [==============================] - 42s 423ms/step - loss: 0.1230 - lr: 6.4000e-04\n",
+ "Epoch 23/30\n",
+ "100/100 [==============================] - 42s 421ms/step - loss: 0.1108 - lr: 6.4000e-04\n",
+ "Epoch 24/30\n",
+ "100/100 [==============================] - 42s 422ms/step - loss: 0.0976 - lr: 6.4000e-04\n",
+ "Epoch 25/30\n",
+ "100/100 [==============================] - 42s 422ms/step - loss: 0.1116 - lr: 6.4000e-04\n",
+ "Epoch 26/30\n",
+ "100/100 [==============================] - ETA: 0s - loss: 0.0996\n",
+ "Epoch 26: ReduceLROnPlateau reducing learning rate to 0.0005120000336319208.\n",
+ "100/100 [==============================] - 43s 430ms/step - loss: 0.0996 - lr: 6.4000e-04\n",
+ "Epoch 27/30\n",
+ "100/100 [==============================] - 43s 433ms/step - loss: 0.1010 - lr: 5.1200e-04\n",
+ "Epoch 28/30\n",
+ "100/100 [==============================] - ETA: 0s - loss: 0.0984\n",
+ "Epoch 28: ReduceLROnPlateau reducing learning rate to 0.00040960004553198815.\n",
+ "100/100 [==============================] - 42s 429ms/step - loss: 0.0984 - lr: 5.1200e-04\n",
+ "Epoch 29/30\n",
+ "100/100 [==============================] - 43s 430ms/step - loss: 0.0865 - lr: 4.0960e-04\n",
+ "Epoch 30/30\n",
+ "100/100 [==============================] - 43s 433ms/step - loss: 0.1049 - lr: 4.0960e-04\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# начинаем обучение, не забывая дропнуть столбцы item_id и user_id \n",
+ "# из датафреймов при инициализации генератора.\n",
+ "\n",
+ "# batch_size можно (и лучше) поставить побольше, если вы не ограничены в ресурсах\n",
+ "\n",
+ "model.fit(generator(items=items_ohe_df.drop([\"item_id\"], axis=1), \n",
+ " users=users_ohe_df.drop([\"user_id\"], axis=1), \n",
+ " interactions=interactions_vec,\n",
+ " batch_size=16), \n",
+ " steps_per_epoch=100, \n",
+ " epochs=30, \n",
+ " initial_epoch=0,\n",
+ " callbacks=[decay, t_board, check]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:17:01.453483Z",
+ "iopub.status.busy": "2023-01-22T16:17:01.453198Z",
+ "iopub.status.idle": "2023-01-22T16:17:01.486783Z",
+ "shell.execute_reply": "2023-01-22T16:17:01.485812Z",
+ "shell.execute_reply.started": "2023-01-22T16:17:01.453458Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
+ ]
+ }
+ ],
+ "source": [
+ "i2v.save('i2v.hdf5')\n",
+ "u2v.save('u2v.hdf5')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T19:15:26.511958Z",
+ "start_time": "2021-10-28T19:15:26.151899Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:28.685695Z",
+ "iopub.status.busy": "2023-01-22T16:24:28.685290Z",
+ "iopub.status.idle": "2023-01-22T16:24:30.854186Z",
+ "shell.execute_reply": "2023-01-22T16:24:30.853120Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:28.685657Z"
+ },
+ "id": "94d23f62",
+ "outputId": "4a500ea7-fa38-4455-a51a-2bd0113fa2f2"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1/1 [==============================] - 0s 129ms/step\n",
+ "1/1 [==============================] - 0s 29ms/step\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "array([[0.76927984]], dtype=float32)"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# берем рандомного юзера\n",
+ "rand_uid = np.random.choice(list(users_ohe_df.index))\n",
+ "\n",
+ "# получаем фичи юзера и вектор его просмотров айтемов\n",
+ "user_meta_feats = users_ohe_df.drop([\"user_id\"], axis=1).iloc[rand_uid]\n",
+ "user_interaction_vec = interactions_vec[rand_uid]\n",
+ "\n",
+ "# берем рандомный айтем\n",
+ "rand_iid = np.random.choice(list(items_ohe_df.index))\n",
+ "# получаем фичи айтема\n",
+ "item_feats = items_ohe_df.drop([\"item_id\"], axis=1).iloc[rand_iid]\n",
+ "\n",
+ "# получаем вектор юзера\n",
+ "user_vec = u2v.predict([np.array(user_meta_feats).reshape(1, -1), \n",
+ " np.array(user_interaction_vec).reshape(1, -1)])\n",
+ "\n",
+ "# и вектор айтема\n",
+ "item_vec = i2v.predict(np.array(item_feats).reshape(1, -1))\n",
+ "\n",
+ "# считаем расстояние между вектором юзера и вектором айтема\n",
+ "from sklearn.metrics.pairwise import euclidean_distances as ED\n",
+ "\n",
+ "ED(user_vec, item_vec)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2021-10-28T19:15:28.951471Z",
+ "start_time": "2021-10-28T19:15:27.763367Z"
+ },
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:35.398767Z",
+ "iopub.status.busy": "2023-01-22T16:24:35.398342Z",
+ "iopub.status.idle": "2023-01-22T16:24:37.179114Z",
+ "shell.execute_reply": "2023-01-22T16:24:37.177336Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:35.398731Z"
+ },
+ "id": "d537d3e8",
+ "outputId": "6bdce370-c348-4f0b-cd4f-3cbb0ec3d019"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "216/216 [==============================] - 0s 883us/step\n"
+ ]
+ }
+ ],
+ "source": [
+ "# получаем фичи всех айтемов\n",
+ "items_feats = items_ohe_df.drop([\"item_id\"], axis=1).to_numpy()\n",
+ "# получаем векторы всех айтемов\n",
+ "items_vecs = i2v.predict(items_feats)\n",
+ "\n",
+ "# считаем расстояния\n",
+ "dists = ED(user_vec, items_vecs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:37.200481Z",
+ "iopub.status.busy": "2023-01-22T16:24:37.199790Z",
+ "iopub.status.idle": "2023-01-22T16:24:37.219685Z",
+ "shell.execute_reply": "2023-01-22T16:24:37.218365Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:37.200421Z"
+ },
+ "id": "udY36b_l0okL",
+ "outputId": "53287102-2434-490e-d668-b8085515d4b8"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(6897, 128)"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_vecs.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:38.051890Z",
+ "iopub.status.busy": "2023-01-22T16:24:38.051199Z",
+ "iopub.status.idle": "2023-01-22T16:24:38.063043Z",
+ "shell.execute_reply": "2023-01-22T16:24:38.061416Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:38.051840Z"
+ },
+ "id": "XasFl6RN0snT"
+ },
+ "outputs": [],
+ "source": [
+ "users_meta_feats = users_ohe_df.drop([\"user_id\"], axis=1)\n",
+ "users_interaction_vec = interactions_vec"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:38.561257Z",
+ "iopub.status.busy": "2023-01-22T16:24:38.560146Z",
+ "iopub.status.idle": "2023-01-22T16:24:38.568144Z",
+ "shell.execute_reply": "2023-01-22T16:24:38.566777Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:38.561176Z"
+ },
+ "id": "cntEZU450_MI",
+ "outputId": "c9aace32-281a-4b0e-8e0d-8b0f1088ce9b"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 19)"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "users_meta_feats.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:40.475433Z",
+ "iopub.status.busy": "2023-01-22T16:24:40.472691Z",
+ "iopub.status.idle": "2023-01-22T16:24:40.484559Z",
+ "shell.execute_reply": "2023-01-22T16:24:40.483559Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:40.475392Z"
+ },
+ "id": "kQ1EZolS1B1Y",
+ "outputId": "ca9dc5eb-5519-4c75-f1d8-4945941a46d1"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 6897)"
+ ]
+ },
+ "execution_count": 54,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "users_interaction_vec.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:40.786186Z",
+ "iopub.status.busy": "2023-01-22T16:24:40.785775Z",
+ "iopub.status.idle": "2023-01-22T16:24:40.797826Z",
+ "shell.execute_reply": "2023-01-22T16:24:40.796580Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:40.786151Z"
+ },
+ "id": "hKU4MD7M1dp5",
+ "outputId": "bd79e8e2-8a82-4ff7-d1ac-849e3425c2c8"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 19)"
+ ]
+ },
+ "execution_count": 55,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "np.array(users_meta_feats).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:41.775332Z",
+ "iopub.status.busy": "2023-01-22T16:24:41.774265Z",
+ "iopub.status.idle": "2023-01-22T16:24:41.780836Z",
+ "shell.execute_reply": "2023-01-22T16:24:41.779665Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:41.775281Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "del interactions_vec\n",
+ "del users_df, interactions_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:24:52.966474Z",
+ "iopub.status.busy": "2023-01-22T16:24:52.965002Z",
+ "iopub.status.idle": "2023-01-22T16:24:57.402446Z",
+ "shell.execute_reply": "2023-01-22T16:24:57.401009Z",
+ "shell.execute_reply.started": "2023-01-22T16:24:52.966417Z"
+ },
+ "id": "x16g5FM21XGJ",
+ "outputId": "9b43e3e6-f98b-466a-8b0d-c2aeb0ca724e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "625/625 [==============================] - 1s 809us/step\n",
+ "625/625 [==============================] - 0s 779us/step\n",
+ "812/812 [==============================] - 1s 763us/step\n"
+ ]
+ }
+ ],
+ "source": [
+ "users_vec_1 = u2v.predict([np.array(users_meta_feats.iloc[:20000]), \n",
+ " np.array(users_interaction_vec[:20000])])\n",
+ "users_vec_2 = u2v.predict([np.array(users_meta_feats.iloc[20000:40000]), \n",
+ " np.array(users_interaction_vec[20000:40000])])\n",
+ "users_vec_3 = u2v.predict([np.array(users_meta_feats.iloc[40000:]), \n",
+ " np.array(users_interaction_vec[40000:])])\n",
+ "users_vec = np.concatenate((users_vec_1, users_vec_2, users_vec_3))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:25:54.668894Z",
+ "iopub.status.busy": "2023-01-22T16:25:54.667629Z",
+ "iopub.status.idle": "2023-01-22T16:25:54.674606Z",
+ "shell.execute_reply": "2023-01-22T16:25:54.673447Z",
+ "shell.execute_reply.started": "2023-01-22T16:25:54.668856Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "del users_vec_1, users_vec_2, users_vec_3, users_interaction_vec"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 59,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:25:57.451831Z",
+ "iopub.status.busy": "2023-01-22T16:25:57.451189Z",
+ "iopub.status.idle": "2023-01-22T16:25:57.458982Z",
+ "shell.execute_reply": "2023-01-22T16:25:57.457745Z",
+ "shell.execute_reply.started": "2023-01-22T16:25:57.451795Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 128)"
+ ]
+ },
+ "execution_count": 59,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "users_vec.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:20:21.497209Z",
+ "iopub.status.busy": "2023-01-22T16:20:21.496765Z",
+ "iopub.status.idle": "2023-01-22T16:20:21.504388Z",
+ "shell.execute_reply": "2023-01-22T16:20:21.503250Z",
+ "shell.execute_reply.started": "2023-01-22T16:20:21.497158Z"
+ },
+ "id": "G4pntPu10ogl",
+ "outputId": "557b6f56-dff5-46e9-da97-569001c59c79"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(6897, 128)"
+ ]
+ },
+ "execution_count": 60,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_vecs.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:26:12.212846Z",
+ "iopub.status.busy": "2023-01-22T16:26:12.212440Z",
+ "iopub.status.idle": "2023-01-22T16:26:19.980077Z",
+ "shell.execute_reply": "2023-01-22T16:26:19.978704Z",
+ "shell.execute_reply.started": "2023-01-22T16:26:12.212812Z"
+ },
+ "id": "hnUX3Yte2Jcw"
+ },
+ "outputs": [],
+ "source": [
+ "dists = ED(users_vec, items_vecs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:26:33.221953Z",
+ "iopub.status.busy": "2023-01-22T16:26:33.220783Z",
+ "iopub.status.idle": "2023-01-22T16:26:33.231255Z",
+ "shell.execute_reply": "2023-01-22T16:26:33.229877Z",
+ "shell.execute_reply.started": "2023-01-22T16:26:33.221902Z"
+ },
+ "id": "MDgiwnnu2KHk",
+ "outputId": "ae9eeb6f-8a29-4195-8bdf-eeaa51b6d049"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 6897)"
+ ]
+ },
+ "execution_count": 62,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dists.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:26:36.257959Z",
+ "iopub.status.busy": "2023-01-22T16:26:36.257347Z",
+ "iopub.status.idle": "2023-01-22T16:26:45.531254Z",
+ "shell.execute_reply": "2023-01-22T16:26:45.530120Z",
+ "shell.execute_reply.started": "2023-01-22T16:26:36.257910Z"
+ },
+ "id": "Ru8IQwSV2UrB"
+ },
+ "outputs": [],
+ "source": [
+ "top10_iids_1 = np.argsort(dists[:20000], axis=1)[:,:10]\n",
+ "top10_iids_2 = np.argsort(dists[20000:40000], axis=1)[:,:10]\n",
+ "top10_iids_3 = np.argsort(dists[40000:], axis=1)[:,:10]\n",
+ "top10_iids = np.concatenate((top10_iids_1, top10_iids_2, top10_iids_3))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:28:25.544273Z",
+ "iopub.status.busy": "2023-01-22T16:28:25.543272Z",
+ "iopub.status.idle": "2023-01-22T16:28:25.551809Z",
+ "shell.execute_reply": "2023-01-22T16:28:25.550511Z",
+ "shell.execute_reply.started": "2023-01-22T16:28:25.544233Z"
+ },
+ "id": "pAzg23jU3TSo",
+ "outputId": "baeb6ea9-ca6f-4bb9-da7b-3fc16669db23"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 10)"
+ ]
+ },
+ "execution_count": 64,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "top10_iids.reshape(dists.shape[0], 10).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 65,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:28:37.182827Z",
+ "iopub.status.busy": "2023-01-22T16:28:37.182088Z",
+ "iopub.status.idle": "2023-01-22T16:28:37.190183Z",
+ "shell.execute_reply": "2023-01-22T16:28:37.188831Z",
+ "shell.execute_reply.started": "2023-01-22T16:28:37.182788Z"
+ },
+ "id": "ehH1-C-S6yE9",
+ "outputId": "5a08578e-7bc1-404a-db00-5d2114b7ad28"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 10)"
+ ]
+ },
+ "execution_count": 65,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "top10_iids.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:28:47.517537Z",
+ "iopub.status.busy": "2023-01-22T16:28:47.516704Z",
+ "iopub.status.idle": "2023-01-22T16:28:47.800629Z",
+ "shell.execute_reply": "2023-01-22T16:28:47.799272Z",
+ "shell.execute_reply.started": "2023-01-22T16:28:47.517501Z"
+ },
+ "id": "srptkYsFsk1V"
+ },
+ "outputs": [],
+ "source": [
+ "top10_iids_item = [iid_to_item_id[iid] for iid in top10_iids.reshape(-1)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:28:51.654826Z",
+ "iopub.status.busy": "2023-01-22T16:28:51.653959Z",
+ "iopub.status.idle": "2023-01-22T16:28:51.700602Z",
+ "shell.execute_reply": "2023-01-22T16:28:51.699239Z",
+ "shell.execute_reply.started": "2023-01-22T16:28:51.654791Z"
+ },
+ "id": "GWCz9zErskwn"
+ },
+ "outputs": [],
+ "source": [
+ "top10_iids_item = np.array(top10_iids_item).reshape(top10_iids.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 68,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:28:57.535876Z",
+ "iopub.status.busy": "2023-01-22T16:28:57.535194Z",
+ "iopub.status.idle": "2023-01-22T16:28:57.543046Z",
+ "shell.execute_reply": "2023-01-22T16:28:57.541704Z",
+ "shell.execute_reply.started": "2023-01-22T16:28:57.535836Z"
+ },
+ "id": "pNq_brUisknx",
+ "outputId": "f1980332-ef9e-4920-b470-675a567c815a"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(65974, 10)"
+ ]
+ },
+ "execution_count": 68,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "top10_iids_item.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:29:00.647959Z",
+ "iopub.status.busy": "2023-01-22T16:29:00.646906Z",
+ "iopub.status.idle": "2023-01-22T16:29:00.657077Z",
+ "shell.execute_reply": "2023-01-22T16:29:00.655386Z",
+ "shell.execute_reply.started": "2023-01-22T16:29:00.647919Z"
+ },
+ "id": "z6ussvRSth2h"
+ },
+ "outputs": [],
+ "source": [
+ "df_dssm = pd.DataFrame(columns = ['user_id', 'item_id'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:29:08.741226Z",
+ "iopub.status.busy": "2023-01-22T16:29:08.740780Z",
+ "iopub.status.idle": "2023-01-22T16:29:08.751118Z",
+ "shell.execute_reply": "2023-01-22T16:29:08.750073Z",
+ "shell.execute_reply.started": "2023-01-22T16:29:08.741183Z"
+ },
+ "id": "Y9XvpPzRu82h",
+ "outputId": "8397c843-1427-444e-af8d-f5c5cd19a80d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "Empty DataFrame\n",
+ "Columns: [user_id, item_id]\n",
+ "Index: []"
+ ]
+ },
+ "execution_count": 70,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_dssm.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:29:16.894986Z",
+ "iopub.status.busy": "2023-01-22T16:29:16.894575Z",
+ "iopub.status.idle": "2023-01-22T16:29:16.955651Z",
+ "shell.execute_reply": "2023-01-22T16:29:16.954612Z",
+ "shell.execute_reply.started": "2023-01-22T16:29:16.894935Z"
+ },
+ "id": "KieINSdwvIu7"
+ },
+ "outputs": [],
+ "source": [
+ "df_dssm = pd.DataFrame({'user_id': [uid_to_user_id[uid] for uid in np.arange(top10_iids_item.shape[0])]})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:29:18.345819Z",
+ "iopub.status.busy": "2023-01-22T16:29:18.345404Z",
+ "iopub.status.idle": "2023-01-22T16:29:18.371714Z",
+ "shell.execute_reply": "2023-01-22T16:29:18.370527Z",
+ "shell.execute_reply.started": "2023-01-22T16:29:18.345785Z"
+ },
+ "id": "RSYHUj7IuzT1"
+ },
+ "outputs": [],
+ "source": [
+ "df_dssm['item_id'] = list(top10_iids_item)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:29:19.355843Z",
+ "iopub.status.busy": "2023-01-22T16:29:19.355038Z",
+ "iopub.status.idle": "2023-01-22T16:29:20.100815Z",
+ "shell.execute_reply": "2023-01-22T16:29:20.099612Z",
+ "shell.execute_reply.started": "2023-01-22T16:29:19.355801Z"
+ },
+ "id": "xdPs4HY874OZ"
+ },
+ "outputs": [],
+ "source": [
+ "df_dssm = df_dssm.explode('item_id')\n",
+ "df_dssm['rank'] = df_dssm.groupby('user_id').cumcount() + 1\n",
+ "df_dssm = df_dssm.groupby('user_id').agg({'item_id': list}).reset_index()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:29:20.104964Z",
+ "iopub.status.busy": "2023-01-22T16:29:20.104605Z",
+ "iopub.status.idle": "2023-01-22T16:29:20.117444Z",
+ "shell.execute_reply": "2023-01-22T16:29:20.115641Z",
+ "shell.execute_reply.started": "2023-01-22T16:29:20.104915Z"
+ },
+ "id": "C8kdzzf6wAuz",
+ "outputId": "df930923-8ad5-45f8-f76c-ab847237802d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 2 | \n",
+ " [4457, 4151, 142, 9988, 4475, 4740, 9169, 5982... | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 21 | \n",
+ " [4457, 3734, 9988, 4740, 2954, 2657, 4151, 152... | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 53 | \n",
+ " [4457, 2220, 4151, 142, 4740, 15297, 2657, 134... | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 60 | \n",
+ " [4457, 4151, 142, 9988, 3734, 6443, 4740, 2954... | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 81 | \n",
+ " [4151, 4740, 2657, 4457, 15297, 281, 142, 9169... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id\n",
+ "0 2 [4457, 4151, 142, 9988, 4475, 4740, 9169, 5982...\n",
+ "1 21 [4457, 3734, 9988, 4740, 2954, 2657, 4151, 152...\n",
+ "2 53 [4457, 2220, 4151, 142, 4740, 15297, 2657, 134...\n",
+ "3 60 [4457, 4151, 142, 9988, 3734, 6443, 4740, 2954...\n",
+ "4 81 [4151, 4740, 2657, 4457, 15297, 281, 142, 9169..."
+ ]
+ },
+ "execution_count": 74,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_dssm.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2023-01-22T16:29:31.350415Z",
+ "iopub.status.busy": "2023-01-22T16:29:31.349997Z",
+ "iopub.status.idle": "2023-01-22T16:29:31.715454Z",
+ "shell.execute_reply": "2023-01-22T16:29:31.714324Z",
+ "shell.execute_reply.started": "2023-01-22T16:29:31.350382Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "df_dssm.to_csv('dssm_predictions.csv', index = False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/hw_5_recbool.ipynb b/hw_5_recbool.ipynb
new file mode 100644
index 00000000..cf414323
--- /dev/null
+++ b/hw_5_recbool.ipynb
@@ -0,0 +1 @@
+{"cells":[{"cell_type":"code","execution_count":18,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T17:56:37.034499Z","iopub.status.busy":"2023-01-22T17:56:37.034012Z","iopub.status.idle":"2023-01-22T17:56:37.042666Z","shell.execute_reply":"2023-01-22T17:56:37.041481Z","shell.execute_reply.started":"2023-01-22T17:56:37.034455Z"},"papermill":{"duration":1.244043,"end_time":"2022-11-27T16:33:29.277270","exception":false,"start_time":"2022-11-27T16:33:28.033227","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import ast\n","import json\n","import matplotlib.pyplot as plt\n","import numpy as np\n","import os\n","import pandas as pd\n","import pickle\n","\n","import warnings\n","warnings.filterwarnings('ignore')\n","\n","from collections import Counter\n","from random import randint, random\n","from scipy.sparse import coo_matrix, hstack\n","from sklearn.metrics.pairwise import euclidean_distances, cosine_distances, cosine_similarity"]},{"cell_type":"code","execution_count":20,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:33.523160Z","iopub.status.busy":"2023-01-22T18:02:33.522766Z","iopub.status.idle":"2023-01-22T18:02:36.724444Z","shell.execute_reply":"2023-01-22T18:02:36.723409Z","shell.execute_reply.started":"2023-01-22T18:02:33.523126Z"},"papermill":{"duration":6.445298,"end_time":"2022-11-27T16:33:35.747539","exception":false,"start_time":"2022-11-27T16:33:29.302241","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["interactions_df = pd.read_csv('interactions_processed_kion.csv')\n","users_df = pd.read_csv('users_processed_kion.csv')\n","items_df = 
pd.read_csv('items_processed_kion.csv')"]},{"cell_type":"code","execution_count":21,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:41.118088Z","iopub.status.busy":"2023-01-22T18:02:41.117711Z","iopub.status.idle":"2023-01-22T18:02:42.100146Z","shell.execute_reply":"2023-01-22T18:02:42.098848Z","shell.execute_reply.started":"2023-01-22T18:02:41.118057Z"},"papermill":{"duration":0.925082,"end_time":"2022-11-27T16:33:36.677439","exception":false,"start_time":"2022-11-27T16:33:35.752357","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["interactions_df['t_dat'] = pd.to_datetime(interactions_df['last_watch_dt'], format=\"%Y-%m-%d\")\n","interactions_df['timestamp'] = interactions_df.t_dat.values.astype(np.int64) // 10 ** 9"]},{"cell_type":"code","execution_count":22,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:42.111635Z","iopub.status.busy":"2023-01-22T18:02:42.110287Z","iopub.status.idle":"2023-01-22T18:02:42.408437Z","shell.execute_reply":"2023-01-22T18:02:42.407310Z","shell.execute_reply.started":"2023-01-22T18:02:42.111593Z"},"papermill":{"duration":0.284147,"end_time":"2022-11-27T16:33:36.966533","exception":false,"start_time":"2022-11-27T16:33:36.682386","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["df = interactions_df[['user_id', 'item_id', 'timestamp']].rename(\n"," columns={'user_id': 'user_id:token', 'item_id': 'item_id:token', 'timestamp': 'timestamp:float'})"]},{"cell_type":"code","execution_count":23,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:43.902049Z","iopub.status.busy":"2023-01-22T18:02:43.901227Z","iopub.status.idle":"2023-01-22T18:02:43.927071Z","shell.execute_reply":"2023-01-22T18:02:43.925875Z","shell.execute_reply.started":"2023-01-22T18:02:43.902007Z"},"trusted":true},"outputs":[{"data":{"text/html":["\n","\n","
\n"," \n"," \n"," | \n"," user_id:token | \n"," item_id:token | \n"," timestamp:float | \n","
\n"," \n"," \n"," \n"," | 0 | \n"," 176549 | \n"," 9506 | \n"," 1620691200 | \n","
\n"," \n"," | 1 | \n"," 699317 | \n"," 1659 | \n"," 1622246400 | \n","
\n"," \n"," | 2 | \n"," 656683 | \n"," 7107 | \n"," 1620518400 | \n","
\n"," \n"," | 3 | \n"," 864613 | \n"," 7638 | \n"," 1625443200 | \n","
\n"," \n"," | 4 | \n"," 964868 | \n"," 9506 | \n"," 1619740800 | \n","
\n"," \n"," | ... | \n"," ... | \n"," ... | \n"," ... | \n","
\n"," \n"," | 5476246 | \n"," 648596 | \n"," 12225 | \n"," 1628812800 | \n","
\n"," \n"," | 5476247 | \n"," 546862 | \n"," 9673 | \n"," 1618272000 | \n","
\n"," \n"," | 5476248 | \n"," 697262 | \n"," 15297 | \n"," 1629417600 | \n","
\n"," \n"," | 5476249 | \n"," 384202 | \n"," 16197 | \n"," 1618790400 | \n","
\n"," \n"," | 5476250 | \n"," 319709 | \n"," 4436 | \n"," 1628985600 | \n","
\n"," \n","
\n","
5476251 rows × 3 columns
\n","
"],"text/plain":[" user_id:token item_id:token timestamp:float\n","0 176549 9506 1620691200\n","1 699317 1659 1622246400\n","2 656683 7107 1620518400\n","3 864613 7638 1625443200\n","4 964868 9506 1619740800\n","... ... ... ...\n","5476246 648596 12225 1628812800\n","5476247 546862 9673 1618272000\n","5476248 697262 15297 1629417600\n","5476249 384202 16197 1618790400\n","5476250 319709 4436 1628985600\n","\n","[5476251 rows x 3 columns]"]},"execution_count":23,"metadata":{},"output_type":"execute_result"}],"source":["df"]},{"cell_type":"code","execution_count":25,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:47.909321Z","iopub.status.busy":"2023-01-22T18:02:47.908208Z","iopub.status.idle":"2023-01-22T18:02:54.589560Z","shell.execute_reply":"2023-01-22T18:02:54.588499Z","shell.execute_reply.started":"2023-01-22T18:02:47.909281Z"},"papermill":{"duration":7.834652,"end_time":"2022-11-27T16:33:45.906924","exception":false,"start_time":"2022-11-27T16:33:38.072272","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["df.to_csv('recbox_data/recbox_data.inter', index=False, sep='\\t')"]},{"cell_type":"code","execution_count":28,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:02:54.592386Z","iopub.status.busy":"2023-01-22T18:02:54.591996Z","iopub.status.idle":"2023-01-22T18:02:55.527789Z","shell.execute_reply":"2023-01-22T18:02:55.526787Z","shell.execute_reply.started":"2023-01-22T18:02:54.592332Z"},"papermill":{"duration":3.067001,"end_time":"2022-11-27T16:34:04.068318","exception":false,"start_time":"2022-11-27T16:34:01.001317","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import logging\n","from logging import getLogger\n","from recbole.config import Config\n","from recbole.data import create_dataset, data_preparation\n","from recbole.model.sequential_recommender import GRU4Rec, Caser\n","from recbole.trainer import Trainer\n","from recbole.utils import init_seed, init_logger\n","from 
recbole.quick_start import run_recbole"]},{"cell_type":"code","execution_count":29,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:04:51.862473Z","iopub.status.busy":"2023-01-22T18:04:51.862041Z","iopub.status.idle":"2023-01-22T18:04:51.900690Z","shell.execute_reply":"2023-01-22T18:04:51.899741Z","shell.execute_reply.started":"2023-01-22T18:04:51.862435Z"},"papermill":{"duration":0.145622,"end_time":"2022-11-27T16:34:04.220395","exception":false,"start_time":"2022-11-27T16:34:04.074773","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["parameter_dict = {\n"," 'data_path': '',\n"," 'USER_ID_FIELD': 'user_id',\n"," 'ITEM_ID_FIELD': 'item_id',\n"," 'TIME_FIELD': 'timestamp',\n"," 'device': 'GPU',\n"," 'user_inter_num_interval': \"[40,inf)\",\n"," 'item_inter_num_interval': \"[40,inf)\",\n"," 'load_col': {'inter': ['user_id', 'item_id', 'timestamp']},\n"," 'neg_sampling': None,\n"," 'epochs': 10,\n"," 'verbose': -1,\n"," 'show_progress' : False,\n"," 'eval_args': {\n"," 'split': {'RS': [9, 0, 1]},\n"," 'group_by': 'user',\n"," 'order': 'TO',\n"," 'mode': 'full'}\n","}\n","config = Config(model='MultiVAE', dataset='recbox_data', config_dict=parameter_dict)\n","\n","# init random seed\n","init_seed(config['seed'], config['reproducibility'])\n","\n","# logger initialization\n","init_logger(config)\n","logger = getLogger()\n","# Create handlers\n","c_handler = logging.StreamHandler()\n","c_handler.setLevel(logging.INFO)\n","logger.addHandler(c_handler)\n","\n","# write config info into log\n","# 
logger.info(config)"]},{"cell_type":"code","execution_count":30,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:04:55.538201Z","iopub.status.busy":"2023-01-22T18:04:55.537818Z","iopub.status.idle":"2023-01-22T18:05:32.322220Z","shell.execute_reply":"2023-01-22T18:05:32.321423Z","shell.execute_reply.started":"2023-01-22T18:04:55.538170Z"},"papermill":{"duration":42.583583,"end_time":"2022-11-27T16:34:46.811041","exception":false,"start_time":"2022-11-27T16:34:04.227458","status":"completed"},"tags":[],"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["11 Dec 11:56 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n"]}],"source":["dataset = create_dataset(config)\n","logger.info(dataset)"]},{"cell_type":"code","execution_count":31,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:05:32.324208Z","iopub.status.busy":"2023-01-22T18:05:32.323852Z","iopub.status.idle":"2023-01-22T18:05:34.256086Z","shell.execute_reply":"2023-01-22T18:05:34.255320Z","shell.execute_reply.started":"2023-01-22T18:05:32.324171Z"},"papermill":{"duration":2.241551,"end_time":"2022-11-27T16:34:49.059852","exception":false,"start_time":"2022-11-27T16:34:46.818301","status":"completed"},"tags":[],"trusted":true},"outputs":[{"name":"stderr","output_type":"stream","text":["11 Dec 11:56 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 
'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 11:56 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n"]}],"source":["# dataset splitting\n","train_data, valid_data, test_data = data_preparation(config, dataset)"]},{"cell_type":"code","execution_count":32,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:05:34.257762Z","iopub.status.busy":"2023-01-22T18:05:34.257174Z","iopub.status.idle":"2023-01-22T18:05:34.262360Z","shell.execute_reply":"2023-01-22T18:05:34.261553Z","shell.execute_reply.started":"2023-01-22T18:05:34.257723Z"},"papermill":{"duration":0.01694,"end_time":"2022-11-27T16:34:49.085164","exception":false,"start_time":"2022-11-27T16:34:49.068224","status":"completed"},"tags":[],"trusted":true},"outputs":[],"source":["import time"]},{"cell_type":"markdown","metadata":{},"source":["### Использование различных архитектур"]},{"cell_type":"code","execution_count":33,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:05:41.096708Z","iopub.status.busy":"2023-01-22T18:05:41.096214Z","iopub.status.idle":"2023-01-22T18:11:38.568018Z","shell.execute_reply":"2023-01-22T18:11:38.567070Z","shell.execute_reply.started":"2023-01-22T18:05:41.096667Z"},"papermill":{"duration":27259.293886,"end_time":"2022-11-28T00:09:08.387403","exception":false,"start_time":"2022-11-27T16:34:49.093517","status":"completed"},"tags":[],"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["running LightGCN...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 11:56 
INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 11:56 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = 
[40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","n_layers = 2\n","reg_weight = 1e-05\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = 
None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","n_layers = 2\n","reg_weight = 1e-05\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose 
= -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 11:58 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 11:58 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 11:58 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 11:58 INFO LightGCN(\n"," (user_embedding): Embedding(13355, 64)\n"," (item_embedding): Embedding(3294, 64)\n"," (mf_loss): BPRLoss()\n"," (reg_loss): EmbLoss()\n",")\n","Trainable parameters: 1065536\n","LightGCN(\n"," (user_embedding): Embedding(13355, 64)\n"," (item_embedding): Embedding(3294, 64)\n"," (mf_loss): BPRLoss()\n"," (reg_loss): 
EmbLoss()\n",")\n","Trainable parameters: 1065536\n","11 Dec 11:58 INFO FLOPs: 0.0\n","FLOPs: 0.0\n","11 Dec 12:00 INFO epoch 0 training [time: 96.30s, train loss: 201.8552]\n","epoch 0 training [time: 96.30s, train loss: 201.8552]\n","11 Dec 12:00 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:01 INFO epoch 1 training [time: 91.67s, train loss: 166.0586]\n","epoch 1 training [time: 91.67s, train loss: 166.0586]\n","11 Dec 12:01 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:03 INFO epoch 2 training [time: 106.54s, train loss: 156.0221]\n","epoch 2 training [time: 106.54s, train loss: 156.0221]\n","11 Dec 12:03 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:05 INFO epoch 3 training [time: 125.49s, train loss: 149.5909]\n","epoch 3 training [time: 125.49s, train loss: 149.5909]\n","11 Dec 12:05 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:07 INFO epoch 4 training [time: 123.82s, train loss: 146.1851]\n","epoch 4 training [time: 123.82s, train loss: 146.1851]\n","11 Dec 12:07 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:10 INFO epoch 5 training [time: 135.29s, train loss: 143.7300]\n","epoch 5 training [time: 135.29s, train loss: 143.7300]\n","11 Dec 12:10 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:12 INFO epoch 6 training [time: 152.14s, train loss: 141.2300]\n","epoch 6 training [time: 152.14s, train loss: 141.2300]\n","11 Dec 12:12 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: 
saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:14 INFO epoch 7 training [time: 114.99s, train loss: 137.4871]\n","epoch 7 training [time: 114.99s, train loss: 137.4871]\n","11 Dec 12:14 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:16 INFO epoch 8 training [time: 128.70s, train loss: 133.3195]\n","epoch 8 training [time: 128.70s, train loss: 133.3195]\n","11 Dec 12:16 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:18 INFO epoch 9 training [time: 126.52s, train loss: 129.4056]\n","epoch 9 training [time: 126.52s, train loss: 129.4056]\n","11 Dec 12:18 INFO Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Saving current: saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:18 INFO Loading model structure and parameters from saved/LightGCN-Dec-11-2023_11-58-43.pth\n","Loading model structure and parameters from saved/LightGCN-Dec-11-2023_11-58-43.pth\n","11 Dec 12:18 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 48.50 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.07 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 48.50 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.07 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 12:18 INFO best valid : None\n","best valid : None\n","11 Dec 12:18 INFO test result: OrderedDict([('recall@10', 0.0792), ('mrr@10', 0.1685), ('ndcg@10', 0.0795), ('hit@10', 0.3385), ('precision@10', 0.0441)])\n","test 
result: OrderedDict([('recall@10', 0.0792), ('mrr@10', 0.1685), ('ndcg@10', 0.0795), ('hit@10', 0.3385), ('precision@10', 0.0441)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 21.95 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0792), ('mrr@10', 0.1685), ('ndcg@10', 0.0795), ('hit@10', 0.3385), ('precision@10', 0.0441)])}\n","running MultiVAE...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 12:18 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 12:18 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 
4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","mlp_hidden_size = [600]\n","latent_dimension = 128\n","dropout_prob = 0.5\n","anneal_cap = 0.2\n","total_anneal_steps = 200000\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = 
True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = 
tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","mlp_hidden_size = [600]\n","latent_dimension = 128\n","dropout_prob = 0.5\n","anneal_cap = 0.2\n","total_anneal_steps = 200000\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 12:21 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 12:21 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 12:21 INFO [Evaluation]: 
eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 12:21 WARNING Max value of user's history interaction records has reached 20.9471766848816% of the total.\n","Max value of user's history interaction records has reached 20.9471766848816% of the total.\n","11 Dec 12:21 INFO MultiVAE(\n"," (encoder): Sequential(\n"," (0): Linear(in_features=3294, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=128, bias=True)\n"," )\n"," (decoder): Sequential(\n"," (0): Linear(in_features=64, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=3294, bias=True)\n"," )\n",")\n","Trainable parameters: 4072622\n","MultiVAE(\n"," (encoder): Sequential(\n"," (0): Linear(in_features=3294, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=128, bias=True)\n"," )\n"," (decoder): Sequential(\n"," (0): Linear(in_features=64, out_features=600, bias=True)\n"," (1): Tanh()\n"," (2): Linear(in_features=600, out_features=3294, bias=True)\n"," )\n",")\n","Trainable parameters: 4072622\n","11 Dec 12:21 INFO FLOPs: 4068000.0\n","FLOPs: 4068000.0\n","11 Dec 12:21 INFO epoch 0 training [time: 2.16s, train loss: 3249.3142]\n","epoch 0 training [time: 2.16s, train loss: 3249.3142]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 1 training [time: 1.96s, train loss: 3098.4010]\n","epoch 1 training [time: 1.96s, train loss: 3098.4010]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 2 
training [time: 1.97s, train loss: 3045.1938]\n","epoch 2 training [time: 1.97s, train loss: 3045.1938]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 3 training [time: 2.02s, train loss: 3008.0520]\n","epoch 3 training [time: 2.02s, train loss: 3008.0520]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 4 training [time: 2.58s, train loss: 2949.4743]\n","epoch 4 training [time: 2.58s, train loss: 2949.4743]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:21 INFO epoch 5 training [time: 2.14s, train loss: 2917.6707]\n","epoch 5 training [time: 2.14s, train loss: 2917.6707]\n","11 Dec 12:21 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 6 training [time: 2.28s, train loss: 2897.4954]\n","epoch 6 training [time: 2.28s, train loss: 2897.4954]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 7 training [time: 2.03s, train loss: 2885.5641]\n","epoch 7 training [time: 2.03s, train loss: 2885.5641]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 8 training [time: 2.38s, train loss: 2871.9012]\n","epoch 8 training [time: 2.38s, train loss: 2871.9012]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO epoch 9 training [time: 2.46s, train loss: 2851.2055]\n","epoch 9 training [time: 2.46s, train loss: 
2851.2055]\n","11 Dec 12:22 INFO Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Saving current: saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO Loading model structure and parameters from saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","Loading model structure and parameters from saved/MultiVAE-Dec-11-2023_12-21-45.pth\n","11 Dec 12:22 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 74.20 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.08 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 74.20 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.08 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 12:22 INFO best valid : None\n","best valid : None\n","11 Dec 12:22 INFO test result: OrderedDict([('recall@10', 0.0839), ('mrr@10', 0.1687), ('ndcg@10', 0.0823), ('hit@10', 0.3494), ('precision@10', 0.0465)])\n","test result: OrderedDict([('recall@10', 0.0839), ('mrr@10', 0.1687), ('ndcg@10', 0.0823), ('hit@10', 0.3494), ('precision@10', 0.0465)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 3.87 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0839), ('mrr@10', 0.1687), ('ndcg@10', 0.0823), ('hit@10', 0.3494), ('precision@10', 0.0465)])}\n","running RecVAE...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 12:22 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', 
'--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 12:22 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id 
= None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","hidden_dimension = 600\n","latent_dimension = 200\n","dropout_prob = 0.5\n","beta = 0.2\n","gamma = 0.005\n","mixture_weights = [0.15, 0.75, 0.1]\n","n_enc_epochs = 3\n","n_dec_epochs = 1\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 
10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = False\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","hidden_dimension = 600\n","latent_dimension = 200\n","dropout_prob = 0.5\n","beta = 0.2\n","gamma = 0.005\n","mixture_weights = [0.15, 0.75, 0.1]\n","n_enc_epochs = 3\n","n_dec_epochs = 1\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = 
False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.GENERAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.PAIRWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 12:25 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 12:25 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 12:25 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 12:25 WARNING Max value of user's history interaction records has reached 20.9471766848816% of the total.\n","Max value of user's history interaction records has reached 20.9471766848816% 
of the total.\n","11 Dec 12:25 INFO RecVAE(\n"," (encoder): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," (prior): CompositePrior(\n"," (encoder_old): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," )\n"," (decoder): Linear(in_features=200, out_features=3294, bias=True)\n",")\n","Trainable parameters: 4327894\n","RecVAE(\n"," (encoder): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," 
(fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," (prior): CompositePrior(\n"," (encoder_old): Encoder(\n"," (fc1): Linear(in_features=3294, out_features=600, bias=True)\n"," (ln1): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc2): Linear(in_features=600, out_features=600, bias=True)\n"," (ln2): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc3): Linear(in_features=600, out_features=600, bias=True)\n"," (ln3): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc4): Linear(in_features=600, out_features=600, bias=True)\n"," (ln4): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc5): Linear(in_features=600, out_features=600, bias=True)\n"," (ln5): LayerNorm((600,), eps=0.1, elementwise_affine=True)\n"," (fc_mu): Linear(in_features=600, out_features=200, bias=True)\n"," (fc_logvar): Linear(in_features=600, out_features=200, bias=True)\n"," )\n"," )\n"," (decoder): Linear(in_features=200, out_features=3294, bias=True)\n",")\n","Trainable parameters: 4327894\n","11 Dec 12:25 INFO FLOPs: 4321200.0\n","FLOPs: 4321200.0\n","11 Dec 12:25 INFO epoch 0 training [time: 23.87s, train loss: 2354.4009]\n","epoch 0 training [time: 23.87s, train loss: 2354.4009]\n","11 Dec 12:25 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:26 INFO epoch 1 training [time: 
26.41s, train loss: 2247.2854]\n","epoch 1 training [time: 26.41s, train loss: 2247.2854]\n","11 Dec 12:26 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:26 INFO epoch 2 training [time: 26.19s, train loss: 2184.4206]\n","epoch 2 training [time: 26.19s, train loss: 2184.4206]\n","11 Dec 12:26 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:27 INFO epoch 3 training [time: 28.26s, train loss: 2147.9836]\n","epoch 3 training [time: 28.26s, train loss: 2147.9836]\n","11 Dec 12:27 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:27 INFO epoch 4 training [time: 25.59s, train loss: 2108.6837]\n","epoch 4 training [time: 25.59s, train loss: 2108.6837]\n","11 Dec 12:27 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:27 INFO epoch 5 training [time: 19.99s, train loss: 2073.2995]\n","epoch 5 training [time: 19.99s, train loss: 2073.2995]\n","11 Dec 12:27 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 6 training [time: 23.58s, train loss: 2043.1616]\n","epoch 6 training [time: 23.58s, train loss: 2043.1616]\n","11 Dec 12:28 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 7 training [time: 24.14s, train loss: 2013.9314]\n","epoch 7 training [time: 24.14s, train loss: 2013.9314]\n","11 Dec 12:28 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 8 training [time: 12.42s, train loss: 1998.7426]\n","epoch 8 training [time: 12.42s, train loss: 1998.7426]\n","11 Dec 12:28 
INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO epoch 9 training [time: 10.69s, train loss: 1973.2974]\n","epoch 9 training [time: 10.69s, train loss: 1973.2974]\n","11 Dec 12:28 INFO Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Saving current: saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:28 INFO Loading model structure and parameters from saved/RecVAE-Dec-11-2023_12-25-16.pth\n","Loading model structure and parameters from saved/RecVAE-Dec-11-2023_12-25-16.pth\n","11 Dec 12:29 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 50.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.09 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 50.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.09 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 12:29 INFO best valid : None\n","best valid : None\n","11 Dec 12:29 INFO test result: OrderedDict([('recall@10', 0.0844), ('mrr@10', 0.1662), ('ndcg@10', 0.0816), ('hit@10', 0.3519), ('precision@10', 0.0468)])\n","test result: OrderedDict([('recall@10', 0.0844), ('mrr@10', 0.1662), ('ndcg@10', 0.0816), ('hit@10', 0.3519), ('precision@10', 0.0468)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 6.70 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0844), ('mrr@10', 0.1662), ('ndcg@10', 0.0816), ('hit@10', 0.3519), ('precision@10', 0.0468)])}\n","CPU times: user 25min 39s, sys: 9min 10s, total: 
34min 49s\n","Wall time: 32min 31s\n"]}],"source":["%%time\n","model_list = [ \"LightGCN\", \"MultiVAE\", \"RecVAE\"] \n","\n","for model_name in model_list:\n"," print(f\"running {model_name}...\")\n"," start = time.time()\n"," result = run_recbole(model=model_name, dataset = 'recbox_data',config_dict = parameter_dict)\n"," t = time.time() - start\n"," print(f\"It took {t/60:.2f} mins\")\n"," print(result)"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Collecting kmeans-pytorch\n"," Downloading kmeans_pytorch-0.3-py3-none-any.whl (4.4 kB)\n","Installing collected packages: kmeans-pytorch\n","Successfully installed kmeans-pytorch-0.3\n","Note: you may need to restart the kernel to use updated packages.\n"]}],"source":["%pip install kmeans-pytorch"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[],"source":["from kmeans_pytorch import kmeans"]},{"cell_type":"code","execution_count":37,"metadata":{"execution":{"iopub.execute_input":"2023-01-22T18:14:48.482175Z","iopub.status.busy":"2023-01-22T18:14:48.481796Z","iopub.status.idle":"2023-01-22T19:32:27.636297Z","shell.execute_reply":"2023-01-22T19:32:27.635371Z","shell.execute_reply.started":"2023-01-22T18:14:48.482143Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["running CORE...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 13:40 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 13:40 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = 
saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = 
None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","inner_size = 256\n","n_layers = 2\n","n_heads = 2\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","dnn_type = trm\n","sess_dropout = 0.2\n","item_dropout = 0.2\n","temperature = 0.07\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 
'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","inner_size = 256\n","n_layers = 2\n","n_heads = 2\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","dnn_type = trm\n","sess_dropout = 0.2\n","item_dropout = 0.2\n","temperature = 0.07\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = 
-1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 13:41 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 13:42 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 13:42 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 13:42 INFO CORE(\n"," (sess_dropout): Dropout(p=0.2, inplace=False)\n"," (item_dropout): Dropout(p=0.2, inplace=False)\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (net): TransNet(\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): TransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x TransformerLayer(\n"," 
(multi_head_attention): MultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (softmax): Softmax(dim=-1)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (fn): Linear(in_features=64, out_features=1, bias=True)\n"," )\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 314177\n","CORE(\n"," (sess_dropout): Dropout(p=0.2, inplace=False)\n"," (item_dropout): Dropout(p=0.2, inplace=False)\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (net): TransNet(\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): TransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x TransformerLayer(\n"," (multi_head_attention): MultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (softmax): Softmax(dim=-1)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): 
Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (fn): Linear(in_features=64, out_features=1, bias=True)\n"," )\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 314177\n","11 Dec 13:42 INFO FLOPs: 4986664.0\n","FLOPs: 4986664.0\n","11 Dec 14:56 INFO epoch 0 training [time: 4465.02s, train loss: 3014.5362]\n","epoch 0 training [time: 4465.02s, train loss: 3014.5362]\n","11 Dec 14:56 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 15:22 INFO epoch 1 training [time: 1551.87s, train loss: 2695.9846]\n","epoch 1 training [time: 1551.87s, train loss: 2695.9846]\n","11 Dec 15:22 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 15:51 INFO epoch 2 training [time: 1725.13s, train loss: 2613.6121]\n","epoch 2 training [time: 1725.13s, train loss: 2613.6121]\n","11 Dec 15:51 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 16:20 INFO epoch 3 training [time: 1788.98s, train loss: 2579.6137]\n","epoch 3 training [time: 1788.98s, train loss: 2579.6137]\n","11 Dec 16:20 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 17:53 INFO epoch 4 training [time: 5548.65s, train loss: 2562.4357]\n","epoch 4 training [time: 5548.65s, train loss: 2562.4357]\n","11 Dec 17:53 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 18:23 INFO epoch 5 training [time: 1835.07s, train 
loss: 2551.7563]\n","epoch 5 training [time: 1835.07s, train loss: 2551.7563]\n","11 Dec 18:23 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 18:52 INFO epoch 6 training [time: 1691.97s, train loss: 2545.4003]\n","epoch 6 training [time: 1691.97s, train loss: 2545.4003]\n","11 Dec 18:52 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 19:23 INFO epoch 7 training [time: 1858.18s, train loss: 2540.3678]\n","epoch 7 training [time: 1858.18s, train loss: 2540.3678]\n","11 Dec 19:23 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 20:06 INFO epoch 8 training [time: 2614.77s, train loss: 2537.1684]\n","epoch 8 training [time: 2614.77s, train loss: 2537.1684]\n","11 Dec 20:06 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 21:22 INFO epoch 9 training [time: 4561.69s, train loss: 2534.2584]\n","epoch 9 training [time: 4561.69s, train loss: 2534.2584]\n","11 Dec 21:22 INFO Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","Saving current: saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 21:22 INFO Loading model structure and parameters from saved/CORE-Dec-11-2023_13-42-03.pth\n","Loading model structure and parameters from saved/CORE-Dec-11-2023_13-42-03.pth\n","11 Dec 21:23 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 46.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.77 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage 
|\n","+=============+===============+\n","| CPU | 46.10 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.77 G/8.00 G |\n","+-------------+---------------+\n","11 Dec 21:23 INFO best valid : None\n","best valid : None\n","11 Dec 21:23 INFO test result: OrderedDict([('recall@10', 0.0921), ('mrr@10', 0.0297), ('ndcg@10', 0.044), ('hit@10', 0.0921), ('precision@10', 0.0092)])\n","test result: OrderedDict([('recall@10', 0.0921), ('mrr@10', 0.0297), ('ndcg@10', 0.044), ('hit@10', 0.0921), ('precision@10', 0.0092)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 463.62 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0921), ('mrr@10', 0.0297), ('ndcg@10', 0.044), ('hit@10', 0.0921), ('precision@10', 0.0092)])}\n","running LightSANs...\n"]},{"name":"stderr","output_type":"stream","text":["11 Dec 21:23 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","11 Dec 21:23 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 
10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","k_interests = 5\n","n_layers = 2\n","n_heads = 2\n","hidden_size = 64\n","inner_size = 256\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","numerical_features = 
[]\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = 
None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","k_interests = 5\n","n_layers = 2\n","n_heads = 2\n","hidden_size = 64\n","inner_size = 256\n","hidden_dropout_prob = 0.5\n","attn_dropout_prob = 0.5\n","hidden_act = gelu\n","layer_norm_eps = 1e-12\n","initializer_range = 0.02\n","loss_type = CE\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","11 Dec 21:31 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 
'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","11 Dec 21:31 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","11 Dec 21:31 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","11 Dec 21:31 INFO LightSANs(\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): LightTransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x LightTransformerLayer(\n"," (multi_head_attention): LightMultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (attpooling_key): ItemToInterestAggregation()\n"," (attpooling_value): ItemToInterestAggregation()\n"," (pos_q_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_k_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_ln): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, 
elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 332288\n","LightSANs(\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (position_embedding): Embedding(50, 64)\n"," (trm_encoder): LightTransformerEncoder(\n"," (layer): ModuleList(\n"," (0-1): 2 x LightTransformerLayer(\n"," (multi_head_attention): LightMultiHeadAttention(\n"," (query): Linear(in_features=64, out_features=64, bias=True)\n"," (key): Linear(in_features=64, out_features=64, bias=True)\n"," (value): Linear(in_features=64, out_features=64, bias=True)\n"," (attpooling_key): ItemToInterestAggregation()\n"," (attpooling_value): ItemToInterestAggregation()\n"," (pos_q_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_k_linear): Linear(in_features=64, out_features=64, bias=True)\n"," (pos_ln): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (attn_dropout): Dropout(p=0.5, inplace=False)\n"," (dense): Linear(in_features=64, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (out_dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," (feed_forward): FeedForward(\n"," (dense_1): Linear(in_features=64, out_features=256, bias=True)\n"," (dense_2): Linear(in_features=256, out_features=64, bias=True)\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," )\n"," )\n"," )\n"," )\n"," (LayerNorm): LayerNorm((64,), eps=1e-12, 
elementwise_affine=True)\n"," (dropout): Dropout(p=0.5, inplace=False)\n"," (loss_fct): CrossEntropyLoss()\n",")\n","Trainable parameters: 332288\n","11 Dec 21:31 INFO FLOPs: 5785664.0\n","FLOPs: 5785664.0\n","11 Dec 22:10 INFO epoch 0 training [time: 2297.51s, train loss: 2745.5162]\n","epoch 0 training [time: 2297.51s, train loss: 2745.5162]\n","11 Dec 22:10 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 22:35 INFO epoch 1 training [time: 1523.42s, train loss: 2594.7601]\n","epoch 1 training [time: 1523.42s, train loss: 2594.7601]\n","11 Dec 22:35 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 23:00 INFO epoch 2 training [time: 1474.23s, train loss: 2552.7842]\n","epoch 2 training [time: 1474.23s, train loss: 2552.7842]\n","11 Dec 23:00 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 23:24 INFO epoch 3 training [time: 1459.72s, train loss: 2529.9439]\n","epoch 3 training [time: 1459.72s, train loss: 2529.9439]\n","11 Dec 23:24 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","11 Dec 23:47 INFO epoch 4 training [time: 1402.02s, train loss: 2516.8341]\n","epoch 4 training [time: 1402.02s, train loss: 2516.8341]\n","11 Dec 23:47 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 00:44 INFO epoch 5 training [time: 3408.17s, train loss: 2508.0956]\n","epoch 5 training [time: 3408.17s, train loss: 2508.0956]\n","12 Dec 00:44 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 01:06 INFO epoch 6 training [time: 1326.43s, train loss: 2501.2301]\n","epoch 6 training [time: 
1326.43s, train loss: 2501.2301]\n","12 Dec 01:06 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 02:02 INFO epoch 7 training [time: 3351.58s, train loss: 2496.2055]\n","epoch 7 training [time: 3351.58s, train loss: 2496.2055]\n","12 Dec 02:02 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 06:08 INFO epoch 8 training [time: 14777.87s, train loss: 2491.3191]\n","epoch 8 training [time: 14777.87s, train loss: 2491.3191]\n","12 Dec 06:08 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 08:40 INFO epoch 9 training [time: 9116.53s, train loss: 2487.0272]\n","epoch 9 training [time: 9116.53s, train loss: 2487.0272]\n","12 Dec 08:40 INFO Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Saving current: saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 08:40 INFO Loading model structure and parameters from saved/LightSANs-Dec-11-2023_21-31-57.pth\n","Loading model structure and parameters from saved/LightSANs-Dec-11-2023_21-31-57.pth\n","12 Dec 08:41 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 30.70 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.82 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 30.70 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.82 G/8.00 G |\n","+-------------+---------------+\n","12 Dec 08:41 INFO best valid : None\n","best valid : None\n","12 Dec 08:41 
INFO test result: OrderedDict([('recall@10', 0.1029), ('mrr@10', 0.0358), ('ndcg@10', 0.0513), ('hit@10', 0.1029), ('precision@10', 0.0103)])\n","test result: OrderedDict([('recall@10', 0.1029), ('mrr@10', 0.0358), ('ndcg@10', 0.0513), ('hit@10', 0.1029), ('precision@10', 0.0103)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 677.87 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.1029), ('mrr@10', 0.0358), ('ndcg@10', 0.0513), ('hit@10', 0.1029), ('precision@10', 0.0103)])}\n","running NextItNet...\n"]},{"name":"stderr","output_type":"stream","text":["12 Dec 08:41 INFO ['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","['/Users/annapikuleva/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py', '--f=/Users/annapikuleva/Library/Jupyter/runtime/kernel-v2-3832937JAU6uqtVOE.json']\n","12 Dec 08:41 INFO \n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 
'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","kernel_size = 3\n","block_num = 5\n","dilations = [1, 4]\n","reg_weight = 1e-05\n","loss_type = CE\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 
'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","\n","General Hyper Parameters:\n","gpu_id = 0\n","use_gpu = True\n","seed = 2020\n","state = INFO\n","reproducibility = True\n","data_path = recbox_data\n","checkpoint_dir = saved\n","show_progress = False\n","save_dataset = False\n","dataset_save_path = None\n","save_dataloaders = False\n","dataloaders_save_path = None\n","log_wandb = False\n","\n","Training Hyper Parameters:\n","epochs = 10\n","train_batch_size = 2048\n","learner = adam\n","learning_rate = 0.001\n","train_neg_sample_args = {'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}\n","eval_step = 1\n","stopping_step = 10\n","clip_grad_norm = None\n","weight_decay = 0.0\n","loss_decimal_place = 4\n","\n","Evaluation Hyper Parameters:\n","eval_args = {'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}\n","repeatable = True\n","metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']\n","topk = [10]\n","valid_metric = MRR@10\n","valid_metric_bigger = True\n","eval_batch_size = 4096\n","metric_decimal_place = 4\n","\n","Dataset Hyper Parameters:\n","field_separator = \t\n","seq_separator = \n","USER_ID_FIELD = user_id\n","ITEM_ID_FIELD = item_id\n","RATING_FIELD = rating\n","TIME_FIELD = timestamp\n","seq_len = None\n","LABEL_FIELD = label\n","threshold = None\n","NEG_PREFIX = neg_\n","load_col = {'inter': ['user_id', 'item_id', 'timestamp']}\n","unload_col = None\n","unused_col = None\n","additional_feat_suffix = None\n","rm_dup_inter = None\n","val_interval = None\n","filter_inter_by_user_or_item = True\n","user_inter_num_interval = [40,inf)\n","item_inter_num_interval = [40,inf)\n","alias_of_user_id = None\n","alias_of_item_id = None\n","alias_of_entity_id = None\n","alias_of_relation_id = None\n","preload_weight = None\n","normalize_field = None\n","normalize_all = None\n","ITEM_LIST_LENGTH_FIELD = 
item_length\n","LIST_SUFFIX = _list\n","MAX_ITEM_LIST_LENGTH = 50\n","POSITION_FIELD = position_id\n","HEAD_ENTITY_ID_FIELD = head_id\n","TAIL_ENTITY_ID_FIELD = tail_id\n","RELATION_ID_FIELD = relation_id\n","ENTITY_ID_FIELD = entity_id\n","benchmark_filename = None\n","\n","Other Hyper Parameters: \n","worker = 0\n","wandb_project = recbole\n","shuffle = True\n","require_pow = False\n","enable_amp = False\n","enable_scaler = False\n","transform = None\n","embedding_size = 64\n","kernel_size = 3\n","block_num = 5\n","dilations = [1, 4]\n","reg_weight = 1e-05\n","loss_type = CE\n","numerical_features = []\n","discretization = None\n","kg_reverse_r = False\n","entity_kg_num_interval = [0,inf)\n","relation_kg_num_interval = [0,inf)\n","MODEL_TYPE = ModelType.SEQUENTIAL\n","device = cpu\n","neg_sampling = None\n","verbose = -1\n","MODEL_INPUT_TYPE = InputType.POINTWISE\n","eval_type = EvaluatorType.RANKING\n","single_spec = True\n","local_rank = 0\n","valid_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","test_neg_sample_args = {'distribution': 'uniform', 'sample_num': 'none'}\n","\n","\n","12 Dec 08:43 INFO recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","recbox_data\n","The number of users: 13355\n","Average actions of users: 63.815710648494836\n","The number of items: 3294\n","Average actions of items: 258.78985727300335\n","The number of inters: 852195\n","The sparsity of the dataset: 98.06281322904924%\n","Remain Fields: ['user_id', 'item_id', 'timestamp']\n","12 Dec 08:43 INFO [Training]: train_batch_size = [2048] train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","[Training]: train_batch_size = [2048] 
train_neg_sample_args: [{'distribution': 'none', 'sample_num': 'none', 'alpha': 'none', 'dynamic': False, 'candidate_num': 0}]\n","12 Dec 08:43 INFO [Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","[Evaluation]: eval_batch_size = [4096] eval_args: [{'split': {'RS': [9, 0, 1]}, 'order': 'TO', 'group_by': 'user', 'mode': {'valid': 'full', 'test': 'full'}}]\n","12 Dec 08:43 INFO NextItNet(\n"," (item_embedding): Embedding(3294, 64, padding_idx=0)\n"," (residual_blocks): Sequential(\n"," (0): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (1): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (2): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (3): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (4): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, 
elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (5): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (6): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (7): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (8): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (9): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," )\n"," (final_layer): Linear(in_features=64, out_features=64, bias=True)\n"," (loss_fct): CrossEntropyLoss()\n"," (reg_loss): RegLoss()\n",")\n","Trainable parameters: 464576\n","NextItNet(\n"," (item_embedding): 
Embedding(3294, 64, padding_idx=0)\n"," (residual_blocks): Sequential(\n"," (0): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (1): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (2): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (3): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (4): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (5): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (6): 
ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (7): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (8): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(2, 2))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," (9): ResidualBlock_b(\n"," (conv1): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(4, 4))\n"," (ln1): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," (conv2): Conv2d(64, 64, kernel_size=(1, 3), stride=(1, 1), dilation=(8, 8))\n"," (ln2): LayerNorm((64,), eps=1e-08, elementwise_affine=True)\n"," )\n"," )\n"," (final_layer): Linear(in_features=64, out_features=64, bias=True)\n"," (loss_fct): CrossEntropyLoss()\n"," (reg_loss): RegLoss()\n",")\n","Trainable parameters: 464576\n","12 Dec 08:43 INFO FLOPs: 12423360.0\n","FLOPs: 12423360.0\n","12 Dec 10:08 INFO epoch 0 training [time: 5095.82s, train loss: 2732.6105]\n","epoch 0 training [time: 5095.82s, train loss: 2732.6105]\n","12 Dec 10:08 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 11:33 INFO epoch 1 training [time: 5097.32s, train loss: 2601.4325]\n","epoch 1 training [time: 5097.32s, train loss: 2601.4325]\n","12 Dec 11:33 INFO Saving current: 
saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 13:31 INFO epoch 2 training [time: 7077.48s, train loss: 2554.3704]\n","epoch 2 training [time: 7077.48s, train loss: 2554.3704]\n","12 Dec 13:31 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 14:39 INFO epoch 3 training [time: 4117.24s, train loss: 2529.2335]\n","epoch 3 training [time: 4117.24s, train loss: 2529.2335]\n","12 Dec 14:39 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 17:57 INFO epoch 4 training [time: 11838.39s, train loss: 2512.5584]\n","epoch 4 training [time: 11838.39s, train loss: 2512.5584]\n","12 Dec 17:57 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 18:42 INFO epoch 5 training [time: 2724.19s, train loss: 2497.9890]\n","epoch 5 training [time: 2724.19s, train loss: 2497.9890]\n","12 Dec 18:42 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 19:30 INFO epoch 6 training [time: 2856.86s, train loss: 2485.4469]\n","epoch 6 training [time: 2856.86s, train loss: 2485.4469]\n","12 Dec 19:30 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 20:13 INFO epoch 7 training [time: 2609.22s, train loss: 2474.9533]\n","epoch 7 training [time: 2609.22s, train loss: 2474.9533]\n","12 Dec 20:13 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 21:18 INFO epoch 8 training [time: 3904.76s, train loss: 2465.3467]\n","epoch 8 training [time: 3904.76s, train loss: 2465.3467]\n","12 Dec 21:18 INFO Saving current: 
saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 22:13 INFO epoch 9 training [time: 3272.21s, train loss: 2456.8602]\n","epoch 9 training [time: 3272.21s, train loss: 2456.8602]\n","12 Dec 22:13 INFO Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Saving current: saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 22:13 INFO Loading model structure and parameters from saved/NextItNet-Dec-12-2023_08-43-30.pth\n","Loading model structure and parameters from saved/NextItNet-Dec-12-2023_08-43-30.pth\n","12 Dec 22:16 INFO The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 11.80 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.74 G/8.00 G |\n","+-------------+---------------+\n","The running environment of this training is as follows:\n","+-------------+---------------+\n","| Environment | Usage |\n","+=============+===============+\n","| CPU | 11.80 % |\n","+-------------+---------------+\n","| GPU | 0.0 / 0.0 |\n","+-------------+---------------+\n","| Memory | 0.74 G/8.00 G |\n","+-------------+---------------+\n","12 Dec 22:16 INFO best valid : None\n","best valid : None\n","12 Dec 22:16 INFO test result: OrderedDict([('recall@10', 0.0922), ('mrr@10', 0.0329), ('ndcg@10', 0.0466), ('hit@10', 0.0922), ('precision@10', 0.0092)])\n","test result: OrderedDict([('recall@10', 0.0922), ('mrr@10', 0.0329), ('ndcg@10', 0.0466), ('hit@10', 0.0922), ('precision@10', 0.0092)])\n"]},{"name":"stdout","output_type":"stream","text":["It took 814.55 mins\n","{'best_valid_score': -inf, 'valid_score_bigger': True, 'best_valid_result': None, 'test_result': OrderedDict([('recall@10', 0.0922), ('mrr@10', 0.0329), ('ndcg@10', 0.0466), ('hit@10', 0.0922), ('precision@10', 0.0092)])}\n","CPU times: user 1d 24min 58s, sys: 13h 28min 
5s, total: 1d 13h 53min 4s\n","Wall time: 1d 8h 36min 2s\n"]}],"source":["%%time\n","model_list = [\"CORE\", \"LightSANs\", \"NextItNet\",] \n","\n","parameter_dict[\"train_neg_sample_args\"] = None\n","\n","for model_name in model_list:\n"," print(f\"running {model_name}...\")\n"," start = time.time()\n"," result = run_recbole(model=model_name, dataset = 'recbox_data', config_dict = parameter_dict)\n"," t = time.time() - start\n"," print(f\"It took {t/60:.2f} mins\")\n"," print(result)"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.13"},"papermill":{"default_parameters":{},"duration":27491.154881,"end_time":"2022-11-28T00:11:27.624787","environment_variables":{},"exception":null,"input_path":"__notebook__.ipynb","output_path":"__notebook__.ipynb","parameters":{},"start_time":"2022-11-27T16:33:16.469906","version":"2.3.4"}},"nbformat":4,"nbformat_minor":5}
diff --git a/service/api/views.py b/service/api/views.py
index 24cf4a7f..95178c1e 100644
--- a/service/api/views.py
+++ b/service/api/views.py
@@ -1,20 +1,38 @@
from typing import List
-
-from fastapi import APIRouter, FastAPI, Request
+from fastapi import APIRouter, Depends, FastAPI, Request
from pydantic import BaseModel
-
-from service.api.exceptions import UserNotFoundError
+import dill
+from service.api.exceptions import ModelNotFoundError, UnauthorizedUserError, UserNotFoundError
from service.log import app_logger
+import pandas as pd
+from service.models import recommend_popular
+
+
+# load predictions of dssm model
+dssm_preds = pd.read_csv("dssm_predictions.csv")
+dssm_preds.item_id = dssm_preds.item_id.apply(lambda x: [int(i) for i in x[1:-1].split(", ")])
+
+
+# get popular recommendations
+interactions = pd.read_csv('data/interactions.csv')
+interactions['last_watch_dt'] = pd.to_datetime(interactions['last_watch_dt'])
+interactions.rename(
+ columns={
+ 'last_watch_dt': 'datetime',
+ 'total_dur': 'weight',
+ },
+ inplace=True,
+ )
+popular_recs = recommend_popular(interactions)
+popular_recs_30 = recommend_popular(interactions, days = 30)
class RecoResponse(BaseModel):
user_id: int
items: List[int]
-
router = APIRouter()
-
@router.get(
path="/health",
tags=["Health"],
@@ -23,26 +41,32 @@ async def health() -> str:
return "I am alive"
+
@router.get(
path="/reco/{model_name}/{user_id}",
tags=["Recommendations"],
- response_model=RecoResponse,
+ response_model=RecoResponse
)
async def get_reco(
- request: Request,
- model_name: str,
- user_id: int,
-) -> RecoResponse:
- app_logger.info(f"Request for model: {model_name}, user_id: {user_id}")
-
- # Write your code here
+ request: Request,
+ model_name: str,
+ user_id: int,
+ # token=Depends(bearer)
+ ) -> RecoResponse:
+ # app_logger.info(f"Request for model: {model_name}, user_id: {user_id}")
+ app_logger.info(f"Request for model: {model_name}")
+ app_logger.info(f"Request for user: {user_id}")
if user_id > 10**9:
raise UserNotFoundError(error_message=f"User {user_id} not found")
+
+ if model_name == "DSSM":
+ try:
+ recs_list = dssm_preds[dssm_preds.user_id == user_id].item_id.values[0]
+ except:
+ recs_list = popular_recs_30
- k_recs = request.app.state.k_recs
- reco = list(range(k_recs))
- return RecoResponse(user_id=user_id, items=reco)
+ return RecoResponse(user_id=user_id, items=recs_list)
def add_views(app: FastAPI) -> None:
diff --git a/userknn.py b/userknn.py
new file mode 100644
index 00000000..e7cf55ab
--- /dev/null
+++ b/userknn.py
@@ -0,0 +1,112 @@
+from typing import Dict
+from collections import Counter
+
+import pandas as pd
+import numpy as np
+import scipy as sp
+from implicit.nearest_neighbours import ItemItemRecommender
+
+
+class UserKnn():
+ """Class for fit-predict UserKNN model
+ based on ItemKNN model from implicit.nearest_neighbours
+ """
+
+ def __init__(self, model: ItemItemRecommender, N_users: int = 50):
+ self.N_users = N_users
+ self.model = model
+ self.is_fitted = False
+
+ def get_mappings(self, train):
+ self.users_inv_mapping = dict(enumerate(train['user_id'].unique()))
+ self.users_mapping = {v: k for k, v in self.users_inv_mapping.items()}
+
+ self.items_inv_mapping = dict(enumerate(train['item_id'].unique()))
+ self.items_mapping = {v: k for k, v in self.items_inv_mapping.items()}
+
+ def get_matrix(self, df: pd.DataFrame,
+ user_col: str = 'user_id',
+ item_col: str = 'item_id',
+ weight_col: str = None,
+ users_mapping: Dict[int, int] = None,
+ items_mapping: Dict[int, int] = None):
+
+ if weight_col:
+ weights = df[weight_col].astype(np.float32)
+ else:
+ weights = np.ones(len(df), dtype=np.float32)
+
+ self.interaction_matrix = sp.sparse.coo_matrix((
+ weights,
+ (
+ df[item_col].map(self.items_mapping.get),
+ df[user_col].map(self.users_mapping.get)
+ )
+ ))
+
+ self.watched = df\
+ .groupby(user_col, as_index=False)\
+ .agg({item_col: list})\
+ .rename(columns={user_col: 'sim_user_id'})
+
+ return self.interaction_matrix
+
+ def idf(self, n: int, x: float):
+ return np.log((1 + n) / (1 + x) + 1)
+
+ def _count_item_idf(self, df: pd.DataFrame):
+ item_cnt = Counter(df['item_id'].values)
+ item_idf = pd.DataFrame.from_dict(item_cnt, orient='index',
+ columns=['doc_freq']).reset_index()
+ item_idf['idf'] = item_idf['doc_freq'].apply(lambda x: self.idf(self.n, x))
+ self.item_idf = item_idf
+
+ def fit(self, train: pd.DataFrame):
+ self.user_knn = self.model
+ self.get_mappings(train)
+ self.weights_matrix = self.get_matrix(train,
+ users_mapping=self.users_mapping,
+ items_mapping=self.items_mapping)
+
+ self.n = train.shape[0]
+ self._count_item_idf(train)
+
+ self.user_knn.fit(self.weights_matrix)
+ self.is_fitted = True
+
+ def _generate_recs_mapper(self, model: ItemItemRecommender, user_mapping: Dict[int, int],
+ user_inv_mapping: Dict[int, int], N: int):
+ def _recs_mapper(user):
+ user_id = self.users_mapping[user]
+ users, sim = model.similar_items(user_id, N=N)
+ return [self.users_inv_mapping[user] for user in users], sim
+ return _recs_mapper
+
+ def predict(self, test: pd.DataFrame, N_recs: int = 10):
+
+ if not self.is_fitted:
+ raise ValueError("Please call fit before predict")
+
+ mapper = self._generate_recs_mapper(
+ model=self.user_knn,
+ user_mapping=self.users_mapping,
+ user_inv_mapping=self.users_inv_mapping,
+ N=self.N_users
+ )
+
+ recs = pd.DataFrame({'user_id': test['user_id'].unique()})
+ recs['sim_user_id'], recs['sim'] = zip(*recs['user_id'].map(mapper))
+ recs = recs.set_index('user_id').apply(pd.Series.explode).reset_index()
+
+ recs = recs[~(recs['user_id'] == recs['sim_user_id'])]\
+ .merge(self.watched, on=['sim_user_id'], how='left')\
+ .explode('item_id')\
+ .sort_values(['user_id', 'sim'], ascending=False)\
+ .drop_duplicates(['user_id', 'item_id'], keep='first')\
+ .merge(self.item_idf, left_on='item_id', right_on='index', how='left')
+
+ recs['score'] = recs['sim'] * recs['idf']
+ recs = recs.sort_values(['user_id', 'score'], ascending=False)
+ recs['rank'] = recs.groupby('user_id').cumcount() + 1
+ return recs[recs['rank'] <= N_recs][['user_id', 'item_id', 'score', 'rank']]
+
\ No newline at end of file