Skip to content

Commit dad2c0a

Browse files
authored
Merge pull request #2 from terryyylim/terence/modeling
Terence/modeling
2 parents b57e307 + 700217e commit dad2c0a

14 files changed

+610
-20
lines changed

.flake8

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[flake8]
2+
max-line-length = 120

.gitignore

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
__pycache__/
44
*.py[cod]
55
*$py.class
6+
.mypy_cache/
67

78
# Unit test / coverage reports
89
.pytest_cache/
@@ -32,4 +33,7 @@ venv.bak/
3233

3334
# Folders
3435
**/datasets
35-
**/models
36+
**/models
37+
38+
# Local Test Run Script
39+
**localrun.py

.travis.yml

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
language: python
2+
python:
3+
- "3.6"
4+
5+
install:
6+
pip install -r requirements.txt
7+
8+
script:
9+
- flake8
10+
- mypy --config-file mypy.ini recommendations/*.py

README.md

+32
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,34 @@
11
# instacart-recommendation
22
Recommendation models using Instacart Orders.
3+
4+
### Quick Start
5+
Clone this repository:
6+
```
7+
git clone https://github.com/terryyylim/instacart-recommendation.git
8+
```
9+
10+
Change to the instacart-recommendation directory:
11+
```
12+
cd instacart-recommendation
13+
```
14+
15+
Create and activate a virtual environment (run `pip3 install virtualenv` first if you don't have Python virtualenv installed):
16+
```
17+
virtualenv -p python3 <desired-path>
18+
source <desired-path>/bin/activate
19+
```
20+
21+
Install the requirements:
22+
```
23+
pip install -r requirements.txt
24+
```
25+
26+
Change to the recommendations directory:
27+
```
28+
cd recommendations
29+
```
30+
31+
### Testing Dataset & Model Creation
32+
```
33+
python run.py
34+
```

mypy.ini

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[mypy]
2+
ignore_missing_imports = True
3+
disallow_untyped_calls = True
4+
disallow_untyped_defs = True
5+
disallow_incomplete_defs = True
6+
check_untyped_defs = True

recommendations/create_dataset.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import click
44
import logging
55
import pandas as pd
6-
import numpy as np
76
import psycopg2 as pg
87
from prettytable import PrettyTable
98
from lightfm.data import Dataset as LFMDataset
@@ -163,12 +162,14 @@ def build_lightfm_dataset(self) -> None:
163162
((rating['user_id'], rating['product_id']) for rating in ratings_list)
164163
)
165164

165+
self.n_users, self.n_items = self.interactions.shape
166+
166167
logging.info(f'Logging self.interactions @build_lightfm_dataset: \n{self.interactions}')
167168
logging.info(f'Logging self.weights @build_lightfm_dataset: \n{self.weights}')
168169
logging.info(
169170
f'The shape of self.interactions {self.interactions.shape} '
170171
f'and self.weights {self.weights.shape} represent the user-item matrix.')
171-
172+
172173

173174
@click.command()
174175
@click.option('--config', default='production', help='the deployment target')
@@ -185,4 +186,4 @@ def main(config: str) -> None:
185186
if __name__ == "__main__":
186187
logger = helpers.get_logger()
187188

188-
main()
189+
main()

recommendations/create_model.py

+254
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
from typing import List
2+
from typing import Optional
3+
4+
import logging
5+
6+
import os
7+
import click
8+
from pathlib import Path
9+
from annoy import AnnoyIndex
10+
from lightfm import LightFM
11+
from lightfm.evaluation import auc_score
12+
from lightfm.cross_validation import random_train_test_split as train_test_split
13+
14+
import helpers
15+
from create_dataset import Dataset # noqa
16+
from config import model_configurations
17+
from config import ModelConfig
18+
19+
20+
class Model:
    """
    Trains LightFM models on a dataset and exports their embeddings as Annoy indexes.

    Attributes assigned over the object's lifetime:
        config: the configuration object passed to the constructor.
        persistence_path: ``config.PATHS.models``; directory the Annoy index files are written to.
        dataset: the dataset object (either passed in or loaded from a file).
        input_file: path the dataset was loaded from (only set when it was loaded from a file).
        model, test_score, num_epochs: set by ``build_model``.
        cab_model: set by ``build_cab_model``.
    """

    def __init__(
        self,
        config: 'ModelConfig',
        input_file: Optional[str] = None,
        input_dataset: Optional['Dataset'] = None,
    ) -> None:
        """
        Store the configuration and load the dataset.

        Args:
            config: model configuration; must expose ``PATHS.models`` (and ``PATHS.data`` when
                neither ``input_file`` nor ``input_dataset`` is given).
            input_file: optional path of a persisted dataset to load.
            input_dataset: optional already-built dataset object; takes precedence over ``input_file``.
        """
        self.config = config
        self.persistence_path = self.config.PATHS.models
        # Forward the caller's choices; calling load_dataset() bare would silently ignore them.
        self.load_dataset(input_file=input_file, input_dataset=input_dataset)

    def load_dataset(self, input_file: Optional[str] = None, input_dataset: Optional['Dataset'] = None) -> None:
        """
        Populate ``self.dataset``.

        Resolution order:
            1. ``input_dataset`` if given (a warning is logged when ``input_file`` is given as well).
            2. ``input_file`` if given.
            3. The file returned by ``helpers.find_latest_file(config.PATHS.data)``.

        ``self.input_file`` is only set when the dataset is loaded from a file.
        """
        logging.info('Loading dataset...')

        if input_dataset is not None:
            if input_file is not None:
                logging.warning('Both an input file and an input Dataset object were provided. Using the object.')
            self.dataset = input_dataset
            return

        if input_file is None:
            self.input_file = helpers.find_latest_file(self.config.PATHS.data)
        else:
            self.input_file = click.format_filename(input_file)
        self.dataset = helpers.load_input_file(self.input_file)  # type: ignore

    def build_model(self) -> None:
        """
        Fits model for user-variant recommendations and similar variant recommendations.

        The interactions are split into train / validation / test sets. The model is trained in
        rounds of ``FIT_PARAMS['epochs_per_round']`` epochs, for at most ``FIT_PARAMS['epochs']``
        rounds, and training stops early once the validation AUC has not improved for
        ``FIT_PARAMS['early_stopping_rounds']`` consecutive rounds.

        Sets ``self.model``, ``self.test_score`` and ``self.num_epochs`` (the 1-based number of the
        round with the best validation AUC, or 0 if no round improved on the initial value).
        """
        if hasattr(self, 'input_file'):
            logging.info(f'Training the main model with dataset {self.input_file}...')
        else:
            logging.info('Training the model...')

        train_validation, test = train_test_split(
            self.dataset.interactions, **self.config.VALIDATION_PARAMS
        )
        train, validation = train_test_split(
            train_validation, **self.config.VALIDATION_PARAMS
        )

        logging.info(f'train: Type; {type(train)}, Shape; {train.shape}')
        logging.info(f'validation: Type; {type(validation)}, Shape; {validation.shape}')
        logging.info(f'test: Type; {type(test)}, Shape; {test.shape}')

        fit_params = self.config.FIT_PARAMS
        item_features = self.dataset.item_features
        early_stopping_rounds = fit_params['early_stopping_rounds']

        model = LightFM(**self.config.LIGHTFM_PARAMS)
        warp_auc: List[float] = []
        best_auc = 0.0
        best_round = -1
        no_improvement_rounds = 0

        logging.info(
            f'Training model until validation AUC has not improved in {early_stopping_rounds} epochs...'
        )

        for epoch in range(fit_params['epochs']):
            logging.info(f'Epoch {epoch}...')

            # fit_partial resumes from the current state; fit() would reset the model every round,
            # which makes round-over-round early stopping meaningless.
            model.fit_partial(
                interactions=train,
                item_features=item_features,
                epochs=fit_params['epochs_per_round'],
                num_threads=fit_params['core_count'],
            )
            score = auc_score(
                model=model,
                test_interactions=validation,
                item_features=item_features,
            ).mean()
            warp_auc.append(score)

            if score > best_auc:
                best_auc = score
                best_round = epoch
                no_improvement_rounds = 0
            else:
                no_improvement_rounds += 1

            logging.info(f'[{epoch}]\tvalidation_warp_auc: {score}')

            if no_improvement_rounds >= early_stopping_rounds:
                break

        # Track the best round explicitly: deriving it from len(warp_auc) is only correct when
        # early stopping actually triggered.
        self.num_epochs = best_round + 1
        logging.info('Stopping. Best Iteration:')
        if best_round >= 0:
            logging.info(
                f'[{best_round}]\tvalidation_warp_auc: {warp_auc[best_round]}'
            )
        else:
            logging.info('No round improved on the initial validation AUC.')

        # NOTE: the model evaluated and stored below is the state after the last trained round,
        # not a snapshot taken at the best round.
        logging.info('Calculating AUC score on test set...')
        test_score = auc_score(
            model=model,
            test_interactions=test,
            item_features=item_features,
        ).mean()
        logging.info(f'Test Set AUC Score: {test_score}')

        self.model = model
        self.test_score = test_score

    def build_cab_model(
        self
    ) -> None:
        """
        Fits model for complement variant recommendations. Only interaction matrices are fed
        into the LightFM model, without any content-based information.

        Sets ``self.cab_model``.
        """
        if hasattr(self, 'input_file'):
            logging.info(f'Training the main CAB model with dataset {self.input_file}...')
        else:
            logging.info('Training the CAB model...')

        cab_model = LightFM(**self.config.LIGHTFM_CAB_PARAMS)
        self.cab_model = cab_model.fit(
            interactions=self.dataset.interactions,
            epochs=self.config.FIT_PARAMS['epochs']
        )

    def build_annoy_representations(
        self,
        feature_type: str,
        is_cab: bool
    ) -> None:
        """
        Getting product/user matrix into proper representations required to
        perform Approximate Nearest Neighbors.
        ---------------
        From LightFM get_item_representations - Index 0: Item biases; Index 1: Item embeddings
        - Excluding Content-based information for CAB model since user/item features overpower MF.

        Args:
            feature_type: ``'user'`` or ``'item'``; selects which embeddings are indexed.
            is_cab: when True the embeddings come from ``self.cab_model`` (no feature matrix is
                passed), otherwise from ``self.model`` using the dataset's feature matrix.

        Raises:
            ValueError: if ``feature_type`` is neither ``'user'`` nor ``'item'``.

        The index is saved to ``<persistence_path>/<feature_type>_cab.ann`` for the CAB model and
        ``<persistence_path>/<feature_type>.ann`` otherwise.
        """
        logging.info('Preparing matrix representations for ANN ~')
        if feature_type not in ('user', 'item'):
            raise ValueError('Unknown feature type passed to function')

        if is_cab:
            file_label = '_cab.ann'
            log_prefix = 'CAB '
            emb_dim = self.config.ANNOY_PARAMS['cab_emb_dim']
            source_model = self.cab_model
        else:
            file_label = '.ann'
            log_prefix = ''
            emb_dim = self.config.ANNOY_PARAMS['emb_dim']
            source_model = self.model

        features = getattr(self.dataset, f'{feature_type}_features')
        get_representations = getattr(source_model, f'get_{feature_type}_representations')
        # Index 1 of the returned tuple holds the embeddings (index 0 holds the biases).
        if is_cab:
            latent_repr_emb = get_representations()[1]
        else:
            latent_repr_emb = get_representations(features=features)[1]

        logging.info(
            f'Preparing {log_prefix}Annoy object using {feature_type}_features\n'
            f'Type: {type(features)}\n'
            f'Shape: {features.shape}'
        )

        logging.info(f'Shape of embeddings: {latent_repr_emb.shape}')
        annoy_index = AnnoyIndex(emb_dim, metric=self.config.ANNOY_PARAMS['metric'])
        for row_id, vector in enumerate(latent_repr_emb):
            annoy_index.add_item(row_id, vector)
        annoy_index.build(self.config.ANNOY_PARAMS['trees'])

        # exist_ok avoids the check-then-create race of a separate is_dir() test.
        Path(self.persistence_path).mkdir(parents=True, exist_ok=True)

        persistence_file = os.path.join(self.persistence_path, feature_type + file_label)
        annoy_index.save(persistence_file)

        logging.info(
            f'{feature_type} vector representations saved to {persistence_file}'
        )

    def delete_var(self, classname: str, attrname: str) -> None:
        """
        Delete unused variables that have been assigned.

        Args:
            classname: ``'model'`` to delete the attribute from this object, ``'dataset'`` to
                delete it from ``self.dataset``.
            attrname: name of the attribute to delete.

        Raises:
            ValueError: if ``classname`` is neither ``'model'`` nor ``'dataset'``.
            AttributeError: if the attribute does not exist on the selected object.
        """
        if classname == 'model':
            delattr(self, attrname)
        elif classname == 'dataset':
            delattr(self.dataset, attrname)
        else:
            raise ValueError('Invalid value passed to classname argument. Value must be "dataset" or "model".')

    def clean_up(self) -> None:
        """
        Delete unused repr from LightFM which exhausts available memory

        Removes the ``model`` attribute from this object.
        """
        self.delete_var(classname='model', attrname='model')
229+
230+
@click.command()
@click.option('--input_file', default=None, type=click.Path(exists=True, dir_okay=False))
@click.option('--config', default='production')
def main(input_file: Optional[str], config: str) -> None:
    """
    Command-line entry point: build the models, export the item Annoy indexes and save the result.

    Args:
        input_file: optional path of an existing dataset file (``--input_file``); None when the
            option is omitted.
        config: name of the configuration to resolve through ``helpers.get_configuration``
            (``--config``, default ``'production'``).
    """
    logging.info("Creating model...")

    configuration = helpers.get_configuration(config, model_configurations)
    model = Model(config=configuration, input_file=input_file)

    model.build_model()
    model.build_cab_model()
    model.build_annoy_representations(feature_type='item', is_cab=True)
    model.build_annoy_representations(feature_type='item', is_cab=False)
    model.clean_up()

    try:
        helpers.save(model)
    except Exception:
        # Saving stays best-effort (the failure is not propagated), but the traceback is
        # recorded at ERROR level instead of being discarded behind an INFO message.
        logging.exception('Error while saving model.')


if __name__ == '__main__':
    logger = helpers.get_logger()

    main()

0 commit comments

Comments
 (0)