feat(linear): Add ensemble tree model and solver-aware scoring #18

Open · wants to merge 6 commits into master

Conversation

@shenkha shenkha commented Jul 14, 2025

What does this PR do?

This pull request introduces two major enhancements to the linear tree-based models:

  1. Ensemble Tree Model: Implements an ensemble of tree models to improve prediction accuracy and robustness over a single tree.
  2. Solver-Aware Scoring: Fixes a critical bug in the beam search scoring logic. The logic now correctly calculates path probabilities based on whether an SVM or a Logistic Regression solver is used.

Key Changes:

1. Ensemble of Trees

  • A new EnsembleTreeModel class in libmultilabel/linear/tree.py now manages multiple tree models.
  • The train_ensemble_tree function handles the training of n separate tree models, each with a different random seed for diversity.
  • The ensemble's final predictions are an average of the scores from each tree, providing a more stable and accurate result.
  • This functionality is exposed via a new CLI argument --tree_ensemble_models in main.py and integrated into linear_trainer.py.

Example usage:

python main.py --training_file data/eurlex_raw_texts_train.txt \
               --test_file data/eurlex_raw_texts_test.txt \
               --linear \
               --linear_technique tree \
               --tree_ensemble_models 3
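The ensemble mechanics described above can be sketched roughly as follows. The names `EnsembleTreeModel` and `train_ensemble_tree` come from this PR, but the bodies below are an illustrative assumption about the logic, not the actual libmultilabel implementation:

```python
import numpy as np


class EnsembleTreeModel:
    """Sketch: wraps several tree models and averages their score matrices."""

    def __init__(self, models):
        self.models = models

    def predict_values(self, x):
        # Average the per-label scores produced by each tree, giving a
        # more stable estimate than any single tree.
        all_scores = [model.predict_values(x) for model in self.models]
        return np.mean(all_scores, axis=0)


def train_ensemble_tree(y, x, n_trees=3, train_tree=None):
    """Sketch: train n_trees tree models, varying the random seed so the
    label-space partitions differ across trees."""
    models = []
    for seed in range(n_trees):
        np.random.seed(seed)  # seed drives the randomness inside train_tree
        models.append(train_tree(y, x))
    return EnsembleTreeModel(models)
```

Averaging raw scores (rather than, say, majority voting over predicted label sets) keeps the output in the same score space as a single tree, so downstream thresholding and evaluation code need not change.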

2. Corrected Scoring Logic

  • The _is_lr method in TreeModel now correctly identifies all of LIBLINEAR's Logistic Regression solvers (0, 6, and 7).
  • The _get_scores method has been updated to use the correct scoring function based on the solver type:
    • For Logistic Regression, it now uses log_expit to correctly accumulate log-probabilities along a path in the tree.
    • For SVM-based solvers, it continues to use the existing calculation based on squared hinge loss.

  This fix is crucial for the beam search to find the optimal labels, as the previous implementation incorrectly applied the SVM scoring logic to LR models.
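A standalone sketch of the solver-aware scoring described above. The real logic lives in TreeModel._get_scores; the free function below, including the squared-hinge form of the SVM branch, is an assumption for illustration:

```python
import numpy as np
from scipy.special import log_expit  # numerically stable log(sigmoid(x))

# LIBLINEAR's Logistic Regression solver types, per this PR's _is_lr fix.
LR_SOLVERS = {"0", "6", "7"}


def get_scores(decision_values, parent_score, solver_type):
    """Accumulate a path score for a tree node's children during beam search."""
    if solver_type in LR_SOLVERS:
        # LR: log_expit of the margin is a log-probability, so scores
        # along a root-to-leaf path accumulate additively.
        return parent_score + log_expit(decision_values)
    # SVM: penalize margin violations with the squared hinge loss;
    # margins of at least 1 contribute no penalty.
    return parent_score - np.square(np.maximum(0.0, 1.0 - decision_values))
```

Applying the LR branch to an SVM model (or vice versa) distorts the relative ordering of candidate paths, which is why the beam search could previously miss the optimal labels.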

Test CLI & API (bash tests/autotest.sh)

Test APIs used by main.py.

  • Test Pass
    • (Copy and paste the last outputted line here.)
  • Not Applicable (i.e., the PR does not include API changes.)

Check API Document

If any new APIs are added, please check that their descriptions are added to the API document.

  • API document is updated (linear, nn)
  • Not Applicable (i.e., the PR does not include API changes.)

Test quickstart & API (bash tests/docs/test_changed_document.sh)

If any APIs in quickstarts or tutorials are modified, please run this test to check if the current examples can run correctly after the modified APIs are released.

@shenkha shenkha requested review from cjlin1 and a team as code owners July 14, 2025 14:45
next_level.extend(zip(node.children, children_score.tolist()))

cur_level = sorted(next_level, key=lambda pair: -pair[1])[:beam_width]
next_level = []

num_labels = len(self.root.label_map)
- scores = np.zeros(num_labels)
+ scores = np.full(num_labels, 0.0)
Contributor

Why do we need to modify this line?

Author

My mistake; I have just checked and will revert right away.

return solver_type in ["0", "6", "7"]
return False

def _get_scores(self, pred, parent_score=0.0):
Contributor

We should specify the parameter type. Please see other functions.

@Eleven1Liu Eleven1Liu self-requested a review July 17, 2025 02:13
Contributor

@Eleven1Liu Eleven1Liu left a comment

For the formatting issues mentioned above, please use the black formatter.
