Skip to content

Adjust probability estimation function used in linear tree-based method #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import scipy.sparse as sparse
import sklearn.cluster
import sklearn.preprocessing
from scipy.special import log_expit
from tqdm import tqdm
import psutil

Expand Down Expand Up @@ -54,21 +55,30 @@ def __init__(
self.node_ptr = node_ptr
self.multiclass = False
self._model_separated = False # Indicates whether the model has been separated for pruning tree.
self.estimator_parameter = 3

def sigmoid_A(self, x):
    """Return the log-probability estimate ``log(sigmoid(A * x))``.

    Despite the name, this returns the *logarithm* of the sigmoid — it is
    computed with ``scipy.special.log_expit`` for numerical stability, so
    large negative decision values do not underflow to ``log(0)``.  The
    scaling factor ``A`` is ``self.estimator_parameter``.

    Args:
        x: Decision value(s); a scalar or an array-like of floats.

    Returns:
        ``log(1 / (1 + exp(-A * x)))``, applied elementwise.
    """
    scaled = self.estimator_parameter * x
    return log_expit(scaled)

def predict_values(
self,
x: sparse.csr_matrix,
beam_width: int = 10,
estimation_parameter: int = 3,
) -> np.ndarray:
"""Calculate the probability estimates associated with x.

Args:
x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
beam_width (int, optional): Number of candidates considered during beam search. Defaults to 10.
estimation_parameter (int, optional): The tunable parameter ``A`` of the probability estimation function ``sigmoid(A * preds)``. Defaults to 3.

Returns:
np.ndarray: A matrix with dimension number of instances * number of classes.
"""

self.estimator_parameter = estimation_parameter

if beam_width >= len(self.root.children):
# Beam_width is sufficiently large; pruning not applied.
# Calculates decision values for all nodes.
Expand Down Expand Up @@ -129,7 +139,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)

# Calculate root decision values and scores
root_preds = linear.predict_values(self.root_model, x)
children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds))
children_scores = 0.0 + self.sigmoid_A(root_preds)

slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
all_preds[slice] = root_preds
Expand Down Expand Up @@ -179,7 +189,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
continue
slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
pred = instance_preds[slice]
children_score = score - np.square(np.maximum(0, 1 - pred))
children_score = score + self.sigmoid_A(pred)
next_level.extend(zip(node.children, children_score.tolist()))

cur_level = sorted(next_level, key=lambda pair: -pair[1])[:beam_width]
Expand All @@ -190,7 +200,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
for node, score in cur_level:
slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
pred = instance_preds[slice]
scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred)))
scores[node.label_map] = np.exp(score + self.sigmoid_A(pred))
return scores


Expand Down
1 change: 1 addition & 0 deletions linear_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def linear_test(config, model, datasets, label_mapping):
predict_kwargs = {}
if model.name == "tree":
predict_kwargs["beam_width"] = config.beam_width
predict_kwargs["estimation_parameter"] = config.estimation_parameter

for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
slice = np.s_[i * config.eval_batch_size : (i + 1) * config.eval_batch_size]
Expand Down
8 changes: 8 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ def add_all_arguments(parser):
default=10,
help="The width of the beam search (default: %(default)s)",
)

parser.add_argument(
"--estimation_parameter",
type=float,
default=3,
help="The parameter for probability estimation function (default: %(default)s)"
)

# AttentionXML
parser.add_argument(
"--cluster_size",
Expand Down