Update build_tree function with SparseKmeans implementation #19
base: master
@@ -4,7 +4,7 @@

 import numpy as np
 import scipy.sparse as sparse
-import sklearn.cluster
+from sparsekmeans import LloydKmeans, ElkanKmeans
 import sklearn.preprocessing
 from tqdm import tqdm
 import psutil
@@ -277,24 +277,37 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
     if d >= dmax or label_representation.shape[0] <= K:
         return Node(label_map=label_map, children=[])

-    metalabels = (
-        sklearn.cluster.KMeans(
-            K,
-            random_state=np.random.randint(2**31 - 1),
-            n_init=1,
-            max_iter=300,
-            tol=0.0001,
-            algorithm="elkan",
-        )
-        .fit(label_representation)
-        .labels_
-    )
+    if label_representation.shape[0] > 10000:
+        kmeans = ElkanKmeans(
+            n_clusters=K,
+            max_iter=300,
+            tol=0.0001,
+            random_state=np.random.randint(2**31 - 1),
+            verbose=True
+        )
+    else:
+        kmeans = LloydKmeans(
+            n_clusters=K,
+            max_iter=300,
+            tol=0.0001,
+            random_state=np.random.randint(2**31 - 1),
+            verbose=True
+        )
+
+    metalabels = kmeans.fit(label_representation)
+
+    unique_labels = np.unique(metalabels)

     children = []
     for i in range(K):
         child_representation = label_representation[metalabels == i]
         child_map = label_map[metalabels == i]
-        child = _build_tree(child_representation, child_map, d + 1, K, dmax)
+
+        if len(unique_labels) == K:
+            child = _build_tree(child_representation, child_map, d + 1, K, dmax)
+        else:
+            child = Node(label_map=child_map, children=[])
+
Maybe it's better to have

num_unique_labels = len(np.unique(metalabels))
if num_unique_labels == K:
    children = []
    for i in range(K):
        child_representation = label_representation[metalabels == i]
        child_map = label_map[metalabels == i]
        child = _build_tree(child_representation, child_map, d + 1, K, dmax)
        children.append(child)
else:
    children = [
        Node(label_map=label_map[metalabels == i], children=[])
        for i in range(num_unique_labels)
    ]
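The control flow proposed in this comment can be sketched without NumPy or a real clustering library. This is a minimal illustration, not the project's code: `Node` here is a stand-in dataclass, `build_tree` and `cluster_fn` are hypothetical names, and `cluster_fn` is assumed to return one cluster id per row. Unlike the suggestion, the sketch iterates over the cluster ids that are actually present instead of `range(num_unique_labels)`, on the assumption that ids need not be contiguous when a cluster comes back empty.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Sequence


@dataclass
class Node:
    # Stand-in for the project's Node: the labels under this node and its children.
    label_map: List[int]
    children: List["Node"] = field(default_factory=list)


def build_tree(
    rows: Sequence,
    label_map: Sequence[int],
    d: int,
    K: int,
    dmax: int,
    cluster_fn: Callable[[Sequence, int], List[int]],
) -> Node:
    # Stop at the depth limit or when there are too few rows to split.
    if d >= dmax or len(rows) <= K:
        return Node(label_map=list(label_map))

    assignments = cluster_fn(rows, K)      # hypothetical: one cluster id per row
    cluster_ids = sorted(set(assignments))
    recurse = len(cluster_ids) == K        # recurse only if no cluster is empty

    children = []
    for cid in cluster_ids:
        idx = [j for j, a in enumerate(assignments) if a == cid]
        child_rows = [rows[j] for j in idx]
        child_map = [label_map[j] for j in idx]
        if recurse:
            children.append(build_tree(child_rows, child_map, d + 1, K, dmax, cluster_fn))
        else:
            # Degenerate clustering: keep the groups that exist as leaves.
            children.append(Node(label_map=child_map))
    return Node(label_map=list(label_map), children=children)
```

Deciding once per node whether to recurse keeps the degenerate case from producing empty children for cluster ids that never appear.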
         children.append(child)

     return Node(label_map=label_map, children=children)
@@ -6,3 +6,4 @@ scikit-learn
 scipy<1.14.0
 tqdm
 psutil
+sparsekmeans
Please also add sparsekmeans to Lines 27 to 35 in a0bef91
Line 3 in a0bef91
Sparsekmeans requires Python >= 3.10, whereas LibMultiLabel supports Python >= 3.8.
@khoinpd0411 No need to bump version now, we will release with #20.
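For the Python version mismatch noted in this thread, one hypothetical way to keep older interpreters working is to choose the clustering backend at runtime. The sketch below is an assumption for illustration only and is not what this PR does: `sparsekmeans_supported` and `pick_backend` are made-up helper names, and the `(3, 10)` floor is taken from the comment above.

```python
import importlib.util
import sys


def sparsekmeans_supported(version_info=sys.version_info) -> bool:
    # The thread states that sparsekmeans requires Python >= 3.10.
    return tuple(version_info[:2]) >= (3, 10)


def pick_backend(version_info=sys.version_info) -> str:
    # Fall back to scikit-learn on older interpreters or when the
    # sparsekmeans package is not installed.
    if not sparsekmeans_supported(version_info):
        return "sklearn"
    if importlib.util.find_spec("sparsekmeans") is None:
        return "sklearn"
    return "sparsekmeans"
```

A caller could branch on the returned string to decide which import to perform; raising the project's minimum Python version, as the reply suggests, avoids the branch altogether.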
I'm just wondering why the indentation isn't aligned with line:287. BTW, should we pass the verbose flag through _build_tree() so that we can control the output when training? And would it be better to do something like (I'm not sure)
For the formatting issue, please use the black formatter.
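On the verbose question raised earlier in this thread, a minimal sketch of making the flag a parameter might look like the following. It is hypothetical and is not the snippet the reviewer had in mind: `make_kmeans` is a made-up helper, the two classes are passed in so the block runs without sparsekmeans installed, and `FakeLloyd` and `FakeElkan` are stand-ins. The keyword names and the 10000-row threshold are copied from the diff.

```python
import random

ELKAN_THRESHOLD = 10000  # row count used in the diff to switch algorithms


def make_kmeans(n_rows, K, lloyd_cls, elkan_cls, verbose=False, threshold=ELKAN_THRESHOLD):
    # Pick the class by input size, then build it with one shared set of
    # keyword arguments so the two branches cannot drift apart.
    cls = elkan_cls if n_rows > threshold else lloyd_cls
    return cls(
        n_clusters=K,
        max_iter=300,
        tol=0.0001,
        random_state=random.randrange(2**31 - 1),
        verbose=verbose,
    )


class _FakeKmeans:
    # Stand-in that only records its constructor arguments.
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class FakeLloyd(_FakeKmeans):
    pass


class FakeElkan(_FakeKmeans):
    pass


small = make_kmeans(500, 4, FakeLloyd, FakeElkan)
large = make_kmeans(50000, 4, FakeLloyd, FakeElkan, verbose=True)
```

With a helper of this shape, `_build_tree()` would only need to accept a `verbose` argument and forward it on each recursive call.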