Skip to content

Commit b40a0a8

Browse files
committed
add caching
1 parent fb32a58 commit b40a0a8

File tree

1 file changed

+114
-30
lines changed

1 file changed

+114
-30
lines changed

recommendations/create_prediction.py

+114-30
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import List
22
from typing import Optional
33

4-
import os
54
import click
65
import logging
76
import pandas as pd
@@ -26,7 +25,11 @@ def __init__(
2625
self.config = config
2726
self.persistence_path = self.config.PATHS.predictions
2827
self.load_model(input_file, input_model)
29-
28+
self.user_df = self.model.dataset.user_df
29+
self.item_df = self.model.dataset.item_df
30+
self.interaction_df = self.model.dataset.interaction_df
31+
self.item_no_interactions_list = self.model.dataset.item_no_interactions_list
32+
3033
def load_model(self, input_file: Optional[str] = None, input_model: Optional[Model] = None) -> None:
3134
logging.info(f'Loading model...')
3235

@@ -41,36 +44,91 @@ def load_model(self, input_file: Optional[str] = None, input_model: Optional[Mod
4144
self.input_file = click.format_filename(input_file)
4245
self.model = helpers.load_input_file(self.input_file) # type: ignore
4346

47+
def create_scores_matrix(
    self,
    is_cab: bool
) -> None:
    '''
    Precompute the score of every user against every item and store the
    resulting 2-D array on `self.dot_product` (one row per user).

    Formula: r^ui = f(q^u . p^i + b^u + b^i)
    -------
    r^ui = Prediction for user u and item i
    q^u = User representations
    p^i = Item representations
    b^u = User feature bias
    b^i = Item feature bias

    Only the raw score (q^u . p^i + b^u + b^i) is stored; the link
    function f is not applied here.

    Parameters
    ----------
    is_cab:
        True  -> use `self.model.cab_model`.
        False -> use `self.model.model`.

    Notes
    -----
    Each call overwrites `self.dot_product`, so only the scores of the
    most recent call are kept.
    '''
    if is_cab:
        model = self.model.cab_model
        # NOTE(review): the CAB model is assumed to be trained without item
        # features, so its raw embeddings are already per-item -- confirm.
        item_features = None
    else:
        model = self.model.model
        # The live `predict` path scores this model with the dataset's item
        # features. Without them the model returns one embedding per
        # *feature* rather than per *item*, and the precomputed matrix would
        # not line up with `predict`.
        item_features = self.model.dataset.item_features

    user_bias, user_latent_repr = model.get_user_representations()
    item_bias, item_latent_repr = model.get_item_representations(features=item_features)

    # Same four diagnostics as before, emitted in the same order.
    for label, values in (
        ('user latent features', user_latent_repr),
        ('item latent features', item_latent_repr),
        ('user bias features', user_bias),
        ('item bias features', item_bias),
    ):
        logging.info(
            f'Logging {label} before broadcasting\n'
            f'Type: {type(values)}\n'
            f'Shape: {values.shape}'
        )

    # Turn both bias vectors into columns so they broadcast against the
    # (n_users, n_items) product: user bias down the rows, item bias
    # (transposed to a row) across the columns.
    user_bias = user_bias[:, np.newaxis]
    item_bias = item_bias[:, np.newaxis]

    self.dot_product = user_latent_repr @ item_latent_repr.T + user_bias + item_bias.T
91+
4492
def get_lightfm_recommendation(
4593
self,
4694
user_index: int,
95+
use_precomputed_scores: bool
4796
) -> List[int]:
4897
'''
98+
Top-picks
99+
---------
49100
Main function that creates user-variant recommendation lists.
50101
'''
51-
model = self.model
52-
interaction_df = self.model.dataset.interaction_df
53-
item_df = self.model.dataset.item_df
54-
user_df = self.model.dataset.user_df
55-
item_no_interactions_list = self.model.dataset.item_no_interactions_list
56102
n_users, n_items = self.model.dataset.interactions.shape
57103

58-
is_new_user = (user_index not in list(user_df['user_id']))
104+
is_new_user = (user_index not in list(self.user_df['user_id']))
105+
logging.info('logging in create_prediction file')
106+
logging.info('logging all users')
107+
logging.info(self.user_df)
108+
logging.info(self.user_df['user_id'])
109+
logging.info(is_new_user)
110+
logging.info(self.model.dataset)
111+
logging.info(type(self.model.dataset.item_features))
112+
logging.info(self.model.dataset.item_features)
113+
logging.info(self.model.dataset.item_features.shape)
59114

60115
if is_new_user:
61116
# TODO: Cold-start recommendation
62117
logging.info('Getting prediction for new user ~')
63118
else:
64-
scores = self.model.model.predict(
65-
user_index,
66-
item_ids=np.arange(n_items),
67-
item_features=self.model.dataset.item_features
68-
)
69-
item_df['scores'] = scores
70-
top_items = item_df['product_id'][np.argsort(-scores)]
119+
if use_precomputed_scores:
120+
scores = self.dot_product[user_index]
121+
else:
122+
scores = self.model.model.predict(
123+
user_index,
124+
item_ids=np.arange(n_items),
125+
item_features=self.model.dataset.item_features
126+
)
127+
self.item_df['scores'] = scores
128+
top_items = self.item_df['product_id'][np.argsort(-scores)]
71129
top_items_idx = top_items.index
72-
73-
top_reccs_df = item_df.iloc[top_items_idx, :].head(10)
130+
131+
top_reccs_df = self.item_df.iloc[top_items_idx, :].head(10)
74132
topReccsTable = PrettyTable(['product_id', 'product_name', 'aisle', 'department', 'num', 'scores'])
75133
topReccsTable.add_row([
76134
top_reccs_df['product_id'],
@@ -87,24 +145,36 @@ def get_lightfm_recommendation(
87145
def get_similar_items(
88146
self,
89147
product_id: int,
148+
rec_type: int
90149
) -> pd.DataFrame:
91150
'''
92-
Main function that creates similar variant recommendation lists.
151+
Function that creates recommendation lists.
152+
153+
The intuition behind using less components is reducing the number of latent factors
154+
that can be inferred. And, by excluding item features for the CAB model, recommendations
155+
will be less based off explicit features such as `aisle` and `department`.
156+
-------------------
157+
type:
158+
1 - Similar Items [DEFAULT_PARAMS]
159+
2 - Complement Items [CAB_PARAMS]
93160
'''
94-
item_df = self.model.dataset.item_df
95-
96-
annoy_model = AnnoyIndex(self.model.config.ANNOY_PARAMS['emb_dim'])
97-
annoy_model.load(self.config.PATHS.models + '/item.ann')
161+
logging.info(f'Logging recommendations for {self.model.config.ANNOY_PARAMS[rec_type]}')
162+
if rec_type == 1:
163+
annoy_model = AnnoyIndex(self.model.config.LIGHTFM_PARAMS['no_components'])
164+
annoy_model.load(self.config.PATHS.models + '/item.ann')
165+
elif rec_type == 2:
166+
annoy_model = AnnoyIndex(self.model.config.LIGHTFM_CAB_PARAMS['no_components'])
167+
annoy_model.load(self.config.PATHS.models + '/item_cab.ann')
98168
similar_variants = annoy_model.get_nns_by_item(
99169
product_id,
100170
self.model.config.ANNOY_PARAMS['nn_count'],
101171
search_k=-1,
102172
include_distances=False
103173
)
104-
logging.info('inside sv')
174+
105175
logging.info(type(similar_variants))
106176
logging.info(similar_variants)
107-
similar_variants_df = item_df.iloc[similar_variants, :]
177+
similar_variants_df = self.item_df.iloc[similar_variants, :]
108178

109179
similarVariantsTable = PrettyTable(['product_id', 'product_name', 'aisle', 'department', 'num'])
110180
similarVariantsTable.add_row([
@@ -114,26 +184,40 @@ def get_similar_items(
114184
similar_variants_df['department'],
115185
similar_variants_df['num']
116186
])
117-
logging.info(f'Similar Variants Data: \n{similarVariantsTable}')
187+
logging.info(f'{self.model.config.ANNOY_PARAMS[rec_type]} Data: \n{similarVariantsTable}')
118188

119189
return similar_variants_df
120190

191+
def cache_top_picks(self) -> None:
    '''
    Build `self.user_toppicks_cache`: a dictionary keyed by the positions
    0 .. len(self.user_df) - 1, where each value is whatever
    `get_lightfm_recommendation` returns for that position (scores are
    always computed live, never read from the precomputed matrix).
    '''
    logging.info('Getting Top-Picks for each User')

    def _top_picks_for(position):
        # Log first, then compute, so the log shows which user is in flight.
        logging.info(f'Caching Top-Picks Recommendation for User {position}')
        return self.get_lightfm_recommendation(user_index=position, use_precomputed_scores=False)

    # The attribute is only assigned once every user has been processed.
    self.user_toppicks_cache = {
        position: _top_picks_for(position)
        for position in range(len(self.user_df))
    }
202+
121203

122204
@click.command()
@click.option('--input_file', default=None, type=click.Path(exists=True, dir_okay=False))
@click.option('--config', default='production')
def main(input_file: Optional[str], config: str) -> None:
    '''
    Command-line entry point: build a predictor and log example recommendations.

    Parameters
    ----------
    input_file:
        Optional path to an existing file, forwarded to `UserItemPrediction`.
        None when the option is not given.
    config:
        Configuration name resolved through `helpers.get_configuration`
        (defaults to 'production').
    '''
    logging.info("Let's make a prediction!")
    configuration = helpers.get_configuration(config, prediction_configurations)

    # Forward the CLI value: previously a hard-coded None was passed here, so
    # `--input_file` was silently ignored.
    predictor = UserItemPrediction(config=configuration, input_file=input_file)

    # NOTE: `create_scores_matrix` is intentionally not called here, so no
    # precomputed score matrix exists and `use_precomputed_scores` must stay False.
    predictor.get_similar_items(product_id=configuration.DEFAULT_ITEM_EG, rec_type=1)
    predictor.get_similar_items(product_id=configuration.DEFAULT_ITEM_EG, rec_type=2)
    predictor.get_lightfm_recommendation(
        user_index=configuration.DEFAULT_USER_EG,
        use_precomputed_scores=False,
    )
134218

135219

136220
if __name__ == '__main__':
    # Obtain the project logger first, then hand control to the click command.
    logger = helpers.get_logger()
    main()

0 commit comments

Comments
 (0)