From 2e47265fffe3f029433333953a14977e357d2701 Mon Sep 17 00:00:00 2001 From: takumi uchida Date: Sat, 17 Aug 2019 15:44:34 +0900 Subject: [PATCH] add interface and available regularlization --- pyfm/pylibfm.py | 24 +++++++++++++++++++----- pyfm_fast.pyx | 15 +++++++++++---- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/pyfm/pylibfm.py b/pyfm/pylibfm.py index b0b71a4..f1f1cd8 100644 --- a/pyfm/pylibfm.py +++ b/pyfm/pylibfm.py @@ -48,6 +48,12 @@ class FM: Whether or not to shuffle training dataset before learning seed : int The seed of the pseudo random number generator + reg_0 : float + The regularization parameter of w0 + reg_w : float + The regularization parameter of w + reg_v : float + The regularization parameter of each element in v """ def __init__(self, num_factors=10, @@ -63,7 +69,11 @@ def __init__(self, task='classification', verbose=True, shuffle_training=True, - seed = 28): + seed = 28, + reg_0 = 0.0, + reg_w = 0.0, + reg_v = 0.0, + ): self.num_factors = num_factors self.num_iter = num_iter @@ -86,9 +96,9 @@ def __init__(self, self.t0 = t0 # Regularization Parameters (start with no regularization) - self.reg_0 = 0.0 - self.reg_w = 0.0 - self.reg_v = np.repeat(0.0, num_factors) + self.reg_0 = float(reg_0) + self.reg_w = float(reg_w) + self.reg_v = np.repeat(float(reg_v), num_factors) # local parameters in the lambda_update step self.lambda_w_grad = 0.0 @@ -205,7 +215,11 @@ def fit(self, X, y): shuffle_training, task, self.seed, - verbose) + verbose, + self.reg_0, + self.reg_w, + self.reg_v, + ) return self.fm_fast.fit(X_train_dataset, validation_dataset) diff --git a/pyfm_fast.pyx b/pyfm_fast.pyx index 8c5c82a..0c9aad7 100644 --- a/pyfm_fast.pyx +++ b/pyfm_fast.pyx @@ -51,6 +51,9 @@ cdef class FM_fast(object): task : int seed : int verbose : int + reg_0 : double + reg_w : double + reg_v : np.ndarray[DOUBLE, ndim=1, mode='c'] """ cdef public double w0 @@ -106,7 +109,11 @@ cdef class FM_fast(object): int shuffle_training, int task, int seed, - int verbose): + int verbose, + double reg_0, + double reg_w, + np.ndarray[DOUBLE, ndim=1, mode='c'] reg_v, + ): self.w0 = w0 self.w = w @@ -130,9 +137,9 @@ cdef class FM_fast(object): self.seed = seed self.verbose = verbose - self.reg_0 = 0.0 - self.reg_w = 0.0 - self.reg_v = np.zeros(self.num_factors) + self.reg_0 = reg_0 + self.reg_w = reg_w + self.reg_v = reg_v self.sumloss = 0.0 self.count = 0