Skip to content

Commit f98b4fe

Browse files
committed
fixed distribution samples in examples
1 parent 2089761 commit f98b4fe

13 files changed

+2201
-4953
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
Cloning from github
2+
===================
3+
4+
Pro tip: clone without history (unless you need it)::
5+
6+
git clone --depth 1 [email protected]:awarebayes/RecNN.git
7+
8+
Create ENV and install deps::
9+
10+
conda create --name recnn
11+
conda activate recnn
12+
cd RecNN
13+
pip install -r requirements.txt
14+
15+
Download data from the donwloads section
16+
17+
Start jupyter notebook and jump to the examples folder ::
18+
19+
jupyter-notebook .
20+
21+
Here is how my project directories looks like (shallow)::
22+
23+
RecNN
24+
├── .circleci
25+
├── data
26+
├── docs
27+
├── examples
28+
├── .git
29+
├── .gitignore
30+
├── LICENSE
31+
├── models
32+
├── readme.md
33+
├── recnn
34+
├── requirements.txt
35+
├── res
36+
├── runs
37+
├── setup.cfg
38+
└── setup.py
39+
40+
Here is the data directory (ignore the cache)::
41+
42+
data
43+
├── cache
44+
│ ├── frame_env.pkl
45+
│ └── frame_env_truncated.pkl
46+
├── embeddings
47+
│ └── ml20_pca128.pkl
48+
└── ml-20m
49+
├── genome-scores.csv
50+
├── genome-tags.csv
51+
├── links.csv
52+
├── movies.csv
53+
├── ratings.csv
54+
├── README.txt
55+
└── tags.csv

examples/1. Vanilla RL/2. DDPG.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@
230230
" losses = ddpg_update(test_batch, params, learn=False, step=step)\n",
231231
" \n",
232232
" gen_actions = debug['next_action']\n",
233-
" true_actions = env.embeddings.detach().cpu().numpy()\n",
233+
" true_actions = env.base.embeddings.detach().cpu().numpy()\n",
234234
" \n",
235235
" f = plotter.kde_reconstruction_error(ad, gen_actions, true_actions, cuda)\n",
236236
" writer.add_figure('rec_error',f, losses['step'])\n",
@@ -488,7 +488,7 @@
488488
],
489489
"source": [
490490
"gen_actions = debug['next_action']\n",
491-
"true_actions = env.embeddings.numpy()\n",
491+
"true_actions = env.base.embeddings.numpy()\n",
492492
"\n",
493493
"\n",
494494
"ad = recnn.nn.AnomalyDetector().to(cuda)\n",

examples/1. Vanilla RL/3. TD3.ipynb

+49-116
Large diffs are not rendered by default.

examples/[Results]/1. Ranking.ipynb

+2,043
Large diffs are not rendered by default.

examples/_ Results/2. Diversity Test (Indexes).ipynb examples/[Results]/2. Diversity Test (Indexes).ipynb

+29-103
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
"metadata": {},
2424
"outputs": [],
2525
"source": [
26-
"import torch\n",
26+
"import torchconda install faiss-gpu cudatoolkit=10.0 -c pytorch # For CUDA10\n",
27+
"\n",
2728
"from torch.utils.data import Dataset, DataLoader\n",
2829
"import torch.nn as nn\n",
2930
"import torch.nn.functional as F\n",
@@ -37,12 +38,10 @@
3738
"outputs": [],
3839
"source": [
3940
"import numpy as np\n",
40-
"from scipy.spatial import distance\n",
4141
"from tqdm.auto import tqdm\n",
4242
"import pickle\n",
4343
"import gc\n",
4444
"import json\n",
45-
"import h5py\n",
4645
"import pandas as pd\n",
4746
"\n",
4847
"from IPython.display import clear_output\n",
@@ -70,14 +69,12 @@
7069
"metadata": {},
7170
"outputs": [
7271
{
72+
"output_type": "execute_result",
7373
"data": {
74-
"text/plain": [
75-
"<All keys matched successfully>"
76-
]
74+
"text/plain": "<All keys matched successfully>"
7775
},
78-
"execution_count": 3,
7976
"metadata": {},
80-
"output_type": "execute_result"
77+
"execution_count": 3
8178
}
8279
],
8380
"source": [
@@ -91,105 +88,31 @@
9188
"cell_type": "code",
9289
"execution_count": 4,
9390
"metadata": {},
94-
"outputs": [
95-
{
96-
"data": {
97-
"application/vnd.jupyter.widget-view+json": {
98-
"model_id": "5d711e3a8edd4108a6bdd94a47712092",
99-
"version_major": 2,
100-
"version_minor": 0
101-
},
102-
"text/plain": [
103-
"HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
104-
]
105-
},
106-
"metadata": {},
107-
"output_type": "display_data"
108-
},
109-
{
110-
"name": "stdout",
111-
"output_type": "stream",
112-
"text": [
113-
"\n"
114-
]
115-
},
116-
{
117-
"data": {
118-
"application/vnd.jupyter.widget-view+json": {
119-
"model_id": "3ea67abab15a4e08bb0b13e7b342e34d",
120-
"version_major": 2,
121-
"version_minor": 0
122-
},
123-
"text/plain": [
124-
"HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
125-
]
126-
},
127-
"metadata": {},
128-
"output_type": "display_data"
129-
},
130-
{
131-
"name": "stdout",
132-
"output_type": "stream",
133-
"text": [
134-
"\n"
135-
]
136-
},
137-
{
138-
"data": {
139-
"application/vnd.jupyter.widget-view+json": {
140-
"model_id": "d8120ad7d9724cb9b1c883d80ba2c24a",
141-
"version_major": 2,
142-
"version_minor": 0
143-
},
144-
"text/plain": [
145-
"HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))"
146-
]
147-
},
148-
"metadata": {},
149-
"output_type": "display_data"
150-
},
151-
{
152-
"name": "stdout",
153-
"output_type": "stream",
154-
"text": [
155-
"\n"
156-
]
157-
}
158-
],
91+
"outputs": [],
15992
"source": [
93+
"frame_size = 10\n",
94+
"batch_size = 1\n",
16095
"# embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n",
161-
"env = recnn.data.FrameEnv('../../data/embeddings/ml20_pca128.pkl',\n",
162-
" '../../data/ml-20m/ratings.csv', 10, 1)"
96+
"dirs = recnn.data.env.DataPath(\n",
97+
" base=\"../../data/\",\n",
98+
" embeddings=\"embeddings/ml20_pca128.pkl\",\n",
99+
" ratings=\"ml-20m/ratings.csv\",\n",
100+
" cache=\"cache/frame_env.pkl\", # cache will generate after you run\n",
101+
" use_cache=True\n",
102+
")\n",
103+
"env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)"
163104
]
164105
},
165106
{
166107
"cell_type": "code",
167-
"execution_count": 6,
108+
"execution_count": 5,
168109
"metadata": {},
169110
"outputs": [],
170111
"source": [
171112
"test_batch = next(iter(env.test_dataloader))\n",
172113
"state, action, reward, next_state, done = recnn.data.get_base_batch(test_batch)"
173114
]
174115
},
175-
{
176-
"cell_type": "code",
177-
"execution_count": 7,
178-
"metadata": {},
179-
"outputs": [],
180-
"source": [
181-
"def rank(gen_action, metric):\n",
182-
" scores = []\n",
183-
" for i in movie_embeddings_key_dict.keys():\n",
184-
" scores.append([i, metric(movie_embeddings_key_dict[i], gen_action)])\n",
185-
" scores = list(sorted(scores, key = lambda x: x[1]))\n",
186-
" scores = scores[:10]\n",
187-
" ids = [i[0] for i in scores]\n",
188-
" dist = [i[1] for i in scores]\n",
189-
"\n",
190-
" return ids, dist"
191-
]
192-
},
193116
{
194117
"cell_type": "markdown",
195118
"metadata": {},
@@ -199,15 +122,18 @@
199122
},
200123
{
201124
"cell_type": "code",
202-
"execution_count": 9,
125+
"execution_count": 7,
203126
"metadata": {},
204127
"outputs": [
205128
{
206-
"name": "stderr",
207-
"output_type": "stream",
208-
"text": [
209-
"/home/dev/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:7: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
210-
" import sys\n"
129+
"output_type": "error",
130+
"ename": "ModuleNotFoundError",
131+
"evalue": "No module named 'faiss'",
132+
"traceback": [
133+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
134+
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
135+
"\u001b[0;32m<ipython-input-7-f4db3fa501af>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;31m# test indexes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mindexL2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndexFlatL2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mindexIP\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndexFlatIP\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mindexCOS\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndexFlatIP\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
136+
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'faiss'"
211137
]
212138
}
213139
],
@@ -218,7 +144,7 @@
218144
"indexIP = faiss.IndexFlatIP(128)\n",
219145
"indexCOS = faiss.IndexFlatIP(128)\n",
220146
"\n",
221-
"mov_mat = np.stack(env.movie_embeddings_key_dict.values()).astype('float32')\n",
147+
"mov_mat = env.base.embeddings.detach().cpu().numpy().astype('float32')\n",
222148
"indexL2.add(mov_mat)\n",
223149
"indexIP.add(mov_mat)\n",
224150
"indexCOS.add(normalize(mov_mat, axis=1, norm='l2'))\n",
@@ -1247,9 +1173,9 @@
12471173
"name": "python",
12481174
"nbconvert_exporter": "python",
12491175
"pygments_lexer": "ipython3",
1250-
"version": "3.7.3"
1176+
"version": "3.8.5-final"
12511177
}
12521178
},
12531179
"nbformat": 4,
12541180
"nbformat_minor": 2
1255-
}
1181+
}

0 commit comments

Comments
 (0)