23
23
"metadata" : {},
24
24
"outputs" : [],
25
25
"source" : [
26
- " import torch\n " ,
26
+ " import torchconda install faiss-gpu cudatoolkit=10.0 -c pytorch # For CUDA10\n " ,
27
+ " \n " ,
27
28
" from torch.utils.data import Dataset, DataLoader\n " ,
28
29
" import torch.nn as nn\n " ,
29
30
" import torch.nn.functional as F\n " ,
37
38
"outputs" : [],
38
39
"source" : [
39
40
" import numpy as np\n " ,
40
- " from scipy.spatial import distance\n " ,
41
41
" from tqdm.auto import tqdm\n " ,
42
42
" import pickle\n " ,
43
43
" import gc\n " ,
44
44
" import json\n " ,
45
- " import h5py\n " ,
46
45
" import pandas as pd\n " ,
47
46
" \n " ,
48
47
" from IPython.display import clear_output\n " ,
70
69
"metadata" : {},
71
70
"outputs" : [
72
71
{
72
+ "output_type" : " execute_result" ,
73
73
"data" : {
74
- "text/plain" : [
75
- " <All keys matched successfully>"
76
- ]
74
+ "text/plain" : " <All keys matched successfully>"
77
75
},
78
- "execution_count" : 3 ,
79
76
"metadata" : {},
80
- "output_type " : " execute_result "
77
+ "execution_count " : 3
81
78
}
82
79
],
83
80
"source" : [
91
88
"cell_type" : " code" ,
92
89
"execution_count" : 4 ,
93
90
"metadata" : {},
94
- "outputs" : [
95
- {
96
- "data" : {
97
- "application/vnd.jupyter.widget-view+json" : {
98
- "model_id" : " 5d711e3a8edd4108a6bdd94a47712092" ,
99
- "version_major" : 2 ,
100
- "version_minor" : 0
101
- },
102
- "text/plain" : [
103
- " HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
104
- ]
105
- },
106
- "metadata" : {},
107
- "output_type" : " display_data"
108
- },
109
- {
110
- "name" : " stdout" ,
111
- "output_type" : " stream" ,
112
- "text" : [
113
- " \n "
114
- ]
115
- },
116
- {
117
- "data" : {
118
- "application/vnd.jupyter.widget-view+json" : {
119
- "model_id" : " 3ea67abab15a4e08bb0b13e7b342e34d" ,
120
- "version_major" : 2 ,
121
- "version_minor" : 0
122
- },
123
- "text/plain" : [
124
- " HBox(children=(IntProgress(value=0, max=20000263), HTML(value='')))"
125
- ]
126
- },
127
- "metadata" : {},
128
- "output_type" : " display_data"
129
- },
130
- {
131
- "name" : " stdout" ,
132
- "output_type" : " stream" ,
133
- "text" : [
134
- " \n "
135
- ]
136
- },
137
- {
138
- "data" : {
139
- "application/vnd.jupyter.widget-view+json" : {
140
- "model_id" : " d8120ad7d9724cb9b1c883d80ba2c24a" ,
141
- "version_major" : 2 ,
142
- "version_minor" : 0
143
- },
144
- "text/plain" : [
145
- " HBox(children=(IntProgress(value=0, max=138493), HTML(value='')))"
146
- ]
147
- },
148
- "metadata" : {},
149
- "output_type" : " display_data"
150
- },
151
- {
152
- "name" : " stdout" ,
153
- "output_type" : " stream" ,
154
- "text" : [
155
- " \n "
156
- ]
157
- }
158
- ],
91
+ "outputs" : [],
159
92
"source" : [
93
+ " frame_size = 10\n " ,
94
+ " batch_size = 1\n " ,
160
95
" # embeddgings: https://drive.google.com/open?id=1EQ_zXBR3DKpmJR3jBgLvt-xoOvArGMsL\n " ,
161
- " env = recnn.data.FrameEnv('../../data/embeddings/ml20_pca128.pkl',\n " ,
162
- " '../../data/ml-20m/ratings.csv', 10, 1)"
96
+ " dirs = recnn.data.env.DataPath(\n " ,
97
+ " base=\" ../../data/\" ,\n " ,
98
+ " embeddings=\" embeddings/ml20_pca128.pkl\" ,\n " ,
99
+ " ratings=\" ml-20m/ratings.csv\" ,\n " ,
100
+ " cache=\" cache/frame_env.pkl\" , # cache will generate after you run\n " ,
101
+ " use_cache=True\n " ,
102
+ " )\n " ,
103
+ " env = recnn.data.env.FrameEnv(dirs, frame_size, batch_size)"
163
104
]
164
105
},
165
106
{
166
107
"cell_type" : " code" ,
167
- "execution_count" : 6 ,
108
+ "execution_count" : 5 ,
168
109
"metadata" : {},
169
110
"outputs" : [],
170
111
"source" : [
171
112
" test_batch = next(iter(env.test_dataloader))\n " ,
172
113
" state, action, reward, next_state, done = recnn.data.get_base_batch(test_batch)"
173
114
]
174
115
},
175
- {
176
- "cell_type" : " code" ,
177
- "execution_count" : 7 ,
178
- "metadata" : {},
179
- "outputs" : [],
180
- "source" : [
181
- " def rank(gen_action, metric):\n " ,
182
- " scores = []\n " ,
183
- " for i in movie_embeddings_key_dict.keys():\n " ,
184
- " scores.append([i, metric(movie_embeddings_key_dict[i], gen_action)])\n " ,
185
- " scores = list(sorted(scores, key = lambda x: x[1]))\n " ,
186
- " scores = scores[:10]\n " ,
187
- " ids = [i[0] for i in scores]\n " ,
188
- " dist = [i[1] for i in scores]\n " ,
189
- " \n " ,
190
- " return ids, dist"
191
- ]
192
- },
193
116
{
194
117
"cell_type" : " markdown" ,
195
118
"metadata" : {},
199
122
},
200
123
{
201
124
"cell_type" : " code" ,
202
- "execution_count" : 9 ,
125
+ "execution_count" : 7 ,
203
126
"metadata" : {},
204
127
"outputs" : [
205
128
{
206
- "name" : " stderr" ,
207
- "output_type" : " stream" ,
208
- "text" : [
209
- " /home/dev/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:7: FutureWarning: arrays to stack must be passed as a \" sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n " ,
210
- " import sys\n "
129
+ "output_type" : " error" ,
130
+ "ename" : " ModuleNotFoundError" ,
131
+ "evalue" : " No module named 'faiss'" ,
132
+ "traceback" : [
133
+ " \u001b [0;31m---------------------------------------------------------------------------\u001b [0m" ,
134
+ " \u001b [0;31mModuleNotFoundError\u001b [0m Traceback (most recent call last)" ,
135
+ "\u001b[0;32m<ipython-input-7-f4db3fa501af>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;31m# test indexes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mindexL2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndexFlatL2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mindexIP\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndexFlatIP\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mindexCOS\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfaiss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIndexFlatIP\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
136
+ " \u001b [0;31mModuleNotFoundError\u001b [0m: No module named 'faiss'"
211
137
]
212
138
}
213
139
],
218
144
" indexIP = faiss.IndexFlatIP(128)\n " ,
219
145
" indexCOS = faiss.IndexFlatIP(128)\n " ,
220
146
" \n " ,
221
- " mov_mat = np.stack( env.movie_embeddings_key_dict.values() ).astype('float32')\n " ,
147
+ " mov_mat = env.base.embeddings.detach().cpu().numpy( ).astype('float32')\n " ,
222
148
" indexL2.add(mov_mat)\n " ,
223
149
" indexIP.add(mov_mat)\n " ,
224
150
" indexCOS.add(normalize(mov_mat, axis=1, norm='l2'))\n " ,
1247
1173
"name" : " python" ,
1248
1174
"nbconvert_exporter" : " python" ,
1249
1175
"pygments_lexer" : " ipython3" ,
1250
- "version" : " 3.7.3 "
1176
+ "version" : " 3.8.5-final "
1251
1177
}
1252
1178
},
1253
1179
"nbformat" : 4 ,
1254
1180
"nbformat_minor" : 2
1255
- }
1181
+ }
0 commit comments