
feat(linear): Add ensemble tree model and solver-aware scoring #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Open · wants to merge 6 commits into base: master
127 changes: 105 additions & 22 deletions libmultilabel/linear/tree.py
@@ -10,8 +10,10 @@
import psutil

from . import linear
from scipy.special import log_expit

__all__ = ["train_tree", "TreeModel"]

__all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]


class Node:
@@ -47,13 +49,31 @@ def __init__(
root: Node,
flat_model: linear.FlatModel,
node_ptr: np.ndarray,
options: str,
):
self.name = "tree"
self.root = root
self.flat_model = flat_model
self.node_ptr = node_ptr
self.options = options
self.multiclass = False
self._model_separated = False  # Indicates whether the model has been separated for tree pruning.

def _is_lr(self) -> bool:
options = self.options or ""
options_split = options.split()
if "-s" in options_split:
i = options_split.index("-s")
if i + 1 < len(options_split):
solver_type = options_split[i + 1]
return solver_type in ["0", "6", "7"]
return False

def _get_scores(self, pred, parent_score=0.0):
Contributor: We should specify the parameter type. Please see other functions.

if self._is_lr():
return parent_score + log_expit(pred)
else:
return parent_score - np.square(np.maximum(0, 1 - pred))
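
The branch on _is_lr is what makes scoring solver-aware: LIBLINEAR's -s 0, -s 6, and -s 7 are its logistic-regression solvers (L2-regularized primal, L1-regularized, and L2-regularized dual), whose decision values are log-odds, so log_expit turns them into log-probabilities that accumulate along a root-to-leaf path. Every other solver keeps the squared-hinge penalty used before this change. A minimal sketch of the two branches, with invented decision values:

# Illustration only; pred is an invented decision-value vector.
import numpy as np
from scipy.special import log_expit

pred = np.array([2.0, -1.0, 0.5])
parent_score = 0.0

# LR solvers (-s 0/6/7): log-probabilities add up along the path,
# so exp(final score) is a product of sigmoid probabilities.
lr_score = parent_score + log_expit(pred)

# Other solvers (e.g. squared-hinge SVM): subtract the squared margin violation.
svm_score = parent_score - np.square(np.maximum(0, 1 - pred))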

def predict_values(
self,
@@ -72,44 +92,42 @@ def predict_values(
if beam_width >= len(self.root.children):
# Beam_width is sufficiently large; pruning not applied.
# Calculates decision values for all nodes.
-all_preds = linear.predict_values(self.flat_model, x)  # number of instances * (number of labels + total number of metalabels)
+all_preds = linear.predict_values(
+    self.flat_model, x
+)  # number of instances * (number of labels + total number of metalabels)
else:
# Beam_width is small; pruning applied to reduce computation.
if not self._model_separated:
self._separate_model_for_pruning_tree()
self._model_separated = True
-all_preds = self._prune_tree_and_predict_values(x, beam_width)  # number of instances * (number of labels + total number of metalabels)
+all_preds = self._prune_tree_and_predict_values(
+    x, beam_width
+)  # number of instances * (number of labels + total number of metalabels)
return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])

def _separate_model_for_pruning_tree(self):
"""
This function separates the weights of the root node and its children into (K+1) FlatModels
for efficient beam search traversal in Python.
"""
-tree_flat_model_params = {
-    'bias': self.root.model.bias,
-    'thresholds': 0,
-    'multiclass': False
-}
+tree_flat_model_params = {"bias": self.root.model.bias, "thresholds": 0, "multiclass": False}
slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
self.root_model = linear.FlatModel(
name="root-flattened-tree",
weights=self.flat_model.weights[slice].tocsr(),
**tree_flat_model_params
name="root-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
)

self.subtree_models = []
for i in range(len(self.root.children)):
subtree_weights_start = self.node_ptr[self.root.children[i].index]
-subtree_weights_end = self.node_ptr[self.root.children[i+1].index] if i+1 < len(self.root.children) else -1
+subtree_weights_end = (
+    self.node_ptr[self.root.children[i + 1].index] if i + 1 < len(self.root.children) else -1
+)
slice = np.s_[:, subtree_weights_start:subtree_weights_end]
subtree_flatmodel = linear.FlatModel(
name="subtree-flattened-tree",
weights=self.flat_model.weights[slice].tocsr(),
**tree_flat_model_params
name="subtree-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
)
self.subtree_models.append(subtree_flatmodel)
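
node_ptr is a cumulative count of weight columns over nodes in visit order, so a node's weights occupy the half-open column range node_ptr[index]:node_ptr[index + 1] of the flattened matrix. A toy illustration of that bookkeeping (sizes invented):

# Toy example: three nodes with 4, 5, and 6 weight columns.
import numpy as np

widths = [4, 5, 6]
node_ptr = np.cumsum([0] + widths)        # array([ 0,  4,  9, 15])

# Columns belonging to node 1 in the flattened weight matrix:
cols = np.s_[:, node_ptr[1]:node_ptr[2]]  # columns 4..8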

def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray:
"""Calculates the selective decision values associated with instances x by evaluating only the most relevant subtrees.

@@ -129,7 +147,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)

# Calculate root decision values and scores
root_preds = linear.predict_values(self.root_model, x)
-children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds))
+children_scores = self._get_scores(root_preds)

slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
all_preds[slice] = root_preds
@@ -140,7 +158,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)
# Build a mask where mask[i, j] is True if the j-th subtree is among the top beam_width subtrees for the i-th instance
mask = np.zeros_like(children_scores, dtype=np.bool_)
np.put_along_axis(mask, top_beam_width_indices, True, axis=1)

# Calculate predictions for each subtree with its corresponding instances
for subtree_idx in range(len(self.root.children)):
subtree_model = self.subtree_models[subtree_idx]
@@ -179,7 +197,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
continue
slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
pred = instance_preds[slice]
-children_score = score - np.square(np.maximum(0, 1 - pred))
+children_score = self._get_scores(pred, score)
next_level.extend(zip(node.children, children_score.tolist()))

cur_level = sorted(next_level, key=lambda pair: -pair[1])[:beam_width]
@@ -190,7 +208,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
for node, score in cur_level:
slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
pred = instance_preds[slice]
-scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred)))
+scores[node.label_map] = np.exp(self._get_scores(pred, score))
return scores
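
For reference, the pruning path keeps only the top beam_width subtrees per instance before any subtree model is evaluated; the mask built with np.put_along_axis behaves like this toy example (scores invented, top-k selection shown with argsort for clarity):

# One instance, four subtrees, beam_width = 2.
import numpy as np

children_scores = np.array([[0.1, -0.5, 0.7, 0.2]])
beam_width = 2
top_beam_width_indices = np.argsort(-children_scores, axis=1)[:, :beam_width]

mask = np.zeros_like(children_scores, dtype=np.bool_)
np.put_along_axis(mask, top_beam_width_indices, True, axis=1)
# mask is [[False, False, True, True]]: only subtrees 2 and 3 are evaluated.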


@@ -258,7 +276,7 @@ def visit(node):
pbar.close()

flat_model, node_ptr = _flatten_model(root)
-return TreeModel(root, flat_model, node_ptr)
+return TreeModel(root, flat_model, node_ptr, options)


def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node:
@@ -382,3 +400,68 @@ def visit(node):
node_ptr = np.cumsum([0] + list(map(lambda w: w.shape[1], weights)))

return model, node_ptr


class EnsembleTreeModel:
"""An ensemble of tree models.
The ensemble aggregates predictions from multiple trees to improve accuracy and robustness.
"""

def __init__(self, tree_models: list[TreeModel]):
"""
Args:
tree_models (list[TreeModel]): A list of trained tree models.
"""
self.name = "ensemble-tree"
self.tree_models = tree_models
self.multiclass = False

def predict_values(self, x: sparse.csr_matrix, beam_width: int = 10) -> np.ndarray:
"""Calculates the averaged probability estimates from all trees in the ensemble.

Args:
x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
beam_width (int, optional): Number of candidates considered during beam search for each tree. Defaults to 10.

Returns:
np.ndarray: A matrix with dimension number of instances * number of classes, containing averaged scores.
"""
all_predictions = [model.predict_values(x, beam_width) for model in self.tree_models]
return np.mean(all_predictions, axis=0)


def train_ensemble_tree(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
options: str = "",
K: int = 100,
dmax: int = 10,
n_trees: int = 3,
seed: int = 42,
verbose: bool = True,
) -> EnsembleTreeModel:
"""Trains an ensemble of tree models (Parabel/Bonsai-style).
Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
options (str, optional): The option string passed to liblinear. Defaults to ''.
K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
dmax (int, optional): Maximum depth of the tree. Defaults to 10.
n_trees (int, optional): Number of trees in the ensemble. Defaults to 3.
seed (int, optional): The base random seed for the ensemble. Defaults to 42.
verbose (bool, optional): Output extra progress information. Defaults to True.

Returns:
EnsembleTreeModel: An ensemble model which can be used for prediction.
"""
tree_models = []
for i in range(n_trees):
np.random.seed(seed + i)

tree_model = train_tree(y, x, options, K, dmax, verbose=False)
tree_models.append(tree_model)

if verbose:
print("Ensemble training completed.")

return EnsembleTreeModel(tree_models)
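
Each tree reseeds NumPy's global RNG with seed + i, so the k-means splits during tree building differ across ensemble members while the run stays reproducible. A hypothetical end-to-end usage, assuming y_train, x_train, and x_test are pre-built sparse matrices (the variable names are invented):

# Hypothetical usage; y_train, x_train, x_test are assumed to be
# scipy.sparse.csr_matrix objects (binary labels and TF-IDF features).
from libmultilabel.linear.tree import train_ensemble_tree

ensemble = train_ensemble_tree(
    y_train, x_train, options="-s 0", K=100, dmax=10, n_trees=3, seed=42
)

# Averaged scores over the three trees; with -s 0 (logistic regression),
# each leaf score is exp(sum of log-probabilities) along its path.
scores = ensemble.predict_values(x_test, beam_width=10)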
28 changes: 20 additions & 8 deletions linear_trainer.py
@@ -6,6 +6,7 @@

import libmultilabel.linear as linear
from libmultilabel.common_utils import dump_log, is_multiclass_dataset
from libmultilabel.linear.tree import train_ensemble_tree
from libmultilabel.linear.utils import LINEAR_TECHNIQUES


@@ -21,7 +22,7 @@ def linear_test(config, model, datasets, label_mapping):
scores = []

predict_kwargs = {}
if model.name == "tree":
if model.name == "tree" or model.name == "ensemble-tree":
predict_kwargs["beam_width"] = config.beam_width

for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
@@ -48,13 +49,24 @@ def linear_train(datasets, config):
if multiclass:
raise ValueError("Tree model should only be used with multilabel datasets.")

-model = LINEAR_TECHNIQUES[config.linear_technique](
-    datasets["train"]["y"],
-    datasets["train"]["x"],
-    options=config.liblinear_options,
-    K=config.tree_degree,
-    dmax=config.tree_max_depth,
-)
+if config.tree_ensemble_models > 1:
+    model = train_ensemble_tree(
+        datasets["train"]["y"],
+        datasets["train"]["x"],
+        options=config.liblinear_options,
+        K=config.tree_degree,
+        dmax=config.tree_max_depth,
+        n_trees=config.tree_ensemble_models,
+        seed=config.seed if config.seed is not None else 42,
+    )
+else:
+    model = LINEAR_TECHNIQUES[config.linear_technique](
+        datasets["train"]["y"],
+        datasets["train"]["x"],
+        options=config.liblinear_options,
+        K=config.tree_degree,
+        dmax=config.tree_max_depth,
+    )
else:
model = LINEAR_TECHNIQUES[config.linear_technique](
datasets["train"]["y"],
3 changes: 3 additions & 0 deletions main.py
@@ -223,6 +223,9 @@ def add_all_arguments(parser):
parser.add_argument(
"--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
)
parser.add_argument(
"--tree_ensemble_models", type=int, default=1, help="Number of models in the tree ensemble (default: %(default)s)"
)
parser.add_argument(
"--beam_width",
type=int,
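
With the new flag wired through linear_trainer.py, a hypothetical invocation might look as follows; --beam_width appears just above, the config path is invented, and --linear_technique is assumed to be the existing flag behind config.linear_technique:

# Hypothetical command line; the config path is illustrative only.
python main.py --config example_config/EUR-Lex/tree.yml \
    --linear_technique tree \
    --tree_ensemble_models 3 \
    --beam_width 10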