diff --git a/implicit/evaluation.pyx b/implicit/evaluation.pyx index f25d0da..830e9c4 100644 --- a/implicit/evaluation.pyx +++ b/implicit/evaluation.pyx @@ -76,7 +76,7 @@ cdef _choose(rng, int n, float frac): return arr -cdef _take_tails(arr, int n, return_complement=False, shuffled=False): +cdef _take_tails(arr, int n, rng, return_complement=False, shuffled=False): """ Given an array of (optionally shuffled) integers in the range 0->n, take the indices of the last 'n' occurrences of each integer (tail) -- subject to shuffling. @@ -127,7 +127,7 @@ cdef _take_tails(arr, int n, return_complement=False, shuffled=False): ranges = np.linspace(start, end, num=n + 1, dtype=int)[1:] if shuffled: - shuffled_idx = (sorted_arr + np.random.random(arr.shape)).argsort() + shuffled_idx = (sorted_arr + rng.random(arr.shape)).argsort() tails = shuffled_idx[np.ravel(ranges, order="f")] else: tails = np.ravel(ranges, order="f") @@ -206,7 +206,7 @@ cpdef leave_k_out_split( candidate_data = data[full_candidate_mask] test_idx, train_idx = _take_tails( - candidate_users, K, shuffled=True, return_complement=True + candidate_users, K, random_state, shuffled=True, return_complement=True ) # get all remaining remaining candidate user-item pairs, and prepare to append to