Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions pyfm/pylibfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ class FM:
Whether or not to shuffle training dataset before learning
seed : int
The seed of the pseudo random number generator
reg_0 : float
The regularization parameter of w0
reg_w : float
The regularization parameter of w
reg_v : float
The regularization parameter of each element in v
"""
def __init__(self,
num_factors=10,
Expand All @@ -63,7 +69,11 @@ def __init__(self,
task='classification',
verbose=True,
shuffle_training=True,
seed = 28):
seed = 28,
reg_0 = 0.0,
reg_w = 0.0,
reg_v = 0.0,
):

self.num_factors = num_factors
self.num_iter = num_iter
Expand All @@ -86,9 +96,9 @@ def __init__(self,
self.t0 = t0

# Regularization Parameters (start with no regularization)
self.reg_0 = 0.0
self.reg_w = 0.0
self.reg_v = np.repeat(0.0, num_factors)
self.reg_0 = float(reg_0)
self.reg_w = float(reg_w)
self.reg_v = np.repeat(float(reg_v), num_factors)

# local parameters in the lambda_update step
self.lambda_w_grad = 0.0
Expand Down Expand Up @@ -205,7 +215,11 @@ def fit(self, X, y):
shuffle_training,
task,
self.seed,
verbose)
verbose,
self.reg_0,
self.reg_w,
self.reg_v,
)

return self.fm_fast.fit(X_train_dataset, validation_dataset)

Expand Down
15 changes: 11 additions & 4 deletions pyfm_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ cdef class FM_fast(object):
task : int
seed : int
verbose : int
reg_0 : double
reg_w : double
reg_v : np.ndarray[DOUBLE, ndim=1, mode='c']
"""

cdef public double w0
Expand Down Expand Up @@ -106,7 +109,11 @@ cdef class FM_fast(object):
int shuffle_training,
int task,
int seed,
int verbose):
int verbose,
double reg_0,
double reg_w,
np.ndarray[DOUBLE, ndim=1, mode='c'] reg_v,
):

self.w0 = w0
self.w = w
Expand All @@ -130,9 +137,9 @@ cdef class FM_fast(object):
self.seed = seed
self.verbose = verbose

self.reg_0 = 0.0
self.reg_w = 0.0
self.reg_v = np.zeros(self.num_factors)
self.reg_0 = reg_0
self.reg_w = reg_w
self.reg_v = reg_v

self.sumloss = 0.0
self.count = 0
Expand Down