diff --git a/implicit/evaluation.pyx b/implicit/evaluation.pyx index f25d0da..4f19c41 100644 --- a/implicit/evaluation.pyx +++ b/implicit/evaluation.pyx @@ -205,8 +205,13 @@ cpdef leave_k_out_split( candidate_items = items[full_candidate_mask] candidate_data = data[full_candidate_mask] + # reindex candidate_user indices so they are properly formatted for the + # calculations in _take_tails + xsorted = np.argsort(unique_candidate_users) + reindexed_candidate_users = np.searchsorted(unique_candidate_users[xsorted], candidate_users) + test_idx, train_idx = _take_tails( - candidate_users, K, shuffled=True, return_complement=True + reindexed_candidate_users, K, shuffled=True, return_complement=True ) # get all remaining remaining candidate user-item pairs, and prepare to append to diff --git a/tests/evaluation_test.py b/tests/evaluation_test.py index df02261..8460150 100644 --- a/tests/evaluation_test.py +++ b/tests/evaluation_test.py @@ -15,7 +15,22 @@ def _get_sample_matrix(): def _get_matrix(): - mat = random(100, 100, density=0.5, format="csr", dtype=np.float32) + mat = random(100, 100, density=0.1, format="csr", dtype=np.float32) + return mat.tocoo() + + +def _get_fixed_matrix(): + mat = csr_matrix( + [ + [1, 0, 0, 0], + [3, 2, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 1, 1, 1], + [0, 0, 1, 0], + ] + ) return mat.tocoo() @@ -48,6 +63,10 @@ def test_leave_k_out_outputs_produce_input(): train, test = leave_k_out_split(mat, K=1) assert ((train + test) - mat).nnz == 0 + mat = _get_fixed_matrix() + train, test = leave_k_out_split(mat, K=1) + assert ((train + test) - mat).nnz == 0 + def test_leave_k_split_is_reservable(): """