-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrsvd_train
executable file
·122 lines (107 loc) · 3.98 KB
/
rsvd_train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/python
"""A script to train a regularized SVD model
on a given numpy record array. The data type is assumed to
be rsvd.rating_t, a struct containing a uint16, a uint32 and a uint8.
"""
import sys
import getopt
import numpy as np
import rsvd
# Mirror the rsvd package's metadata so this script exposes the same
# version, author and license strings as the library it imports.
__version__ = rsvd.__version__
__author__ = rsvd.__author__
__license__ = rsvd.__license__
class Usage(Exception):
    """Raised when the command line cannot be interpreted.

    The offending condition is kept in `msg` (a string or the underlying
    getopt error) and is also passed to the base Exception, so that
    `str(err)` and tracebacks show it instead of an empty message.
    """

    def __init__(self, msg):
        super(Usage, self).__init__(msg)
        self.msg = msg
def usage():
    """Print the command-line help for rsvd_train to stdout.

    The defaults quoted in the text mirror the values hard-coded in main();
    options that require a value show a placeholder for it.
    """
    # A single parenthesised argument keeps this valid as both a Python 2
    # print statement and a Python 3 print() call, with identical output.
    print("""Usage: rsvd_train [options] training_array num_movies num_users output_dir
Trains a regularized SVD solver on the given training data and stores the
trained model in `output_dir`. The training data is assumed to be
a serialized numpy record array. `num_movies` and `num_users` are the number
of movies and users, resp., in the data set.
See <http://code.google.com/p/pyrsvd/> for further information.
Options:
-h, --help\tPrint this help
-f <int>\tThe number of latent factors to compute [default: 10].
-l <float>\tLearn rate [default: 0.001]
-r <float>\tRegularization parameter [default: 0.011]
--probe <file>\tEstimate error on probe set.
\t\t<file> contains the probe set in a numpy record array.
\t\tIf defined, early stopping is turned on
--maxepochs <int>\tThe max number of epochs to perform [default: 100].
--randomize\tShuffle the training data every 10th epoch.
--minimprovement <float>\tThe min improvement in RMSE on the probeset to trigger early stopping [default: 0.000001].
For bug reporting, please mail to:
""")
def _load_records(path, label):
    """Read the serialized `rsvd.rating_t` record array stored at `path`.

    Progress is reported on stdout using `label` (e.g. "training").
    The file is opened in binary mode and closed as soon as it has been read.
    """
    sys.stdout.write("Loading %s data...\t" % label)
    sys.stdout.flush()
    with open(path, "rb") as handle:
        records = np.fromfile(handle, dtype=rsvd.rating_t)
    sys.stdout.write("done.\n")
    return records


def main(argv=None):
    """Command-line entry point: train an RSVD model and save it.

    `argv` is the full argument vector including the program name; it
    defaults to `sys.argv`. Expected positionals are
    `training_array num_movies num_users output_dir`.

    Returns 0 on success, 1 if loading, training or saving fails, and 2 if
    the command line is invalid (the help text is printed in that case).
    `-h`/`--help` prints the help and exits via `sys.exit()`.
    """
    if argv is None:
        argv = sys.argv

    # Defaults for every tunable; the help text in usage() quotes these.
    lr = 0.001
    reg = 0.011
    probeFile = None
    max_epochs = 100
    randomize = False
    factors = 10
    min_improvement = 0.000001

    try:
        try:
            opts, args = getopt.getopt(
                argv[1:], "hl:r:f:",
                ["help", "probe=", "maxepochs=", "minimprovement=",
                 "randomize"])
        except getopt.error as msg:
            raise Usage(msg)

        for o, a in opts:
            if o in ("-h", "--help"):
                usage()
                sys.exit()
            # Compare with == : `o in ("--probe")` would be a substring
            # test against a plain string, not tuple membership.
            elif o == "--probe":
                probeFile = a
            elif o == "-l":
                lr = float(a)
            elif o == "-r":
                reg = float(a)
            elif o == "-f":
                factors = int(a)
            elif o == "--maxepochs":
                max_epochs = int(a)
            elif o == "--randomize":
                randomize = True
            elif o == "--minimprovement":
                min_improvement = float(a)
            else:
                raise Usage("unhandled option")

        if len(args) < 4:
            raise Usage("missing arguments")
    except Usage as err:
        sys.stderr.write("Error:  %s\n" % (err.msg,))
        usage()
        return 2

    try:
        nmovies = int(args[1])
        nusers = int(args[2])
        output = args[3]

        ratingsArray = _load_records(args[0], "training")
        probeArray = None
        if probeFile:
            probeArray = _load_records(probeFile, "probe")

        model = rsvd.RSVD.train(factors, ratingsArray, (nmovies, nusers),
                                probeArray=probeArray,
                                maxEpochs=max_epochs,
                                minImprovement=min_improvement,
                                learnRate=lr,
                                regularization=reg,
                                randomize=randomize)
        model.save(output)
    except Exception as e:
        # Top-level boundary: report the failure and signal it through the
        # exit status instead of silently returning success.
        sys.stderr.write("Error:  %s\n" % (e,))
        return 1
    return 0
if __name__ == "__main__":
    # Executed as a script: run the CLI and hand main()'s return value to
    # the interpreter as the process exit status.
    exit_status = main()
    sys.exit(exit_status)