📝 Description
I would like to contribute a new module to the linfa-trees crate that implements the Random Forest algorithm for classification tasks. This will expand linfa-trees from single decision trees into ensemble learning, aligning closely with scikit-learn's functionality in Python.
🚀 Motivation
Random Forests are a powerful ensemble learning method used widely in classification tasks. They provide:
- Robustness to overfitting
- Better generalization than single trees
- Feature importance estimates
Currently, linfa-trees provides support for single decision trees. By adding Random Forests, we unlock ensemble learning for the Rust ML ecosystem.
📐 Proposed Design
🔹 New Module
A new file will be added:
`linfa-trees/src/decision_trees/random_forest.rs`
This will include:
- `RandomForestClassifier<F: Float>`
- `RandomForestParams<F>` (unchecked)
- `RandomForestValidParams<F>` (checked)
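To illustrate the unchecked/checked split, here is a minimal sketch that mirrors linfa's convention without depending on the crate. All names, defaults, and error messages are assumptions for illustration, not settled API (the real version would be generic over `Float` and implement `ParamGuard`):

```rust
/// Checked parameters: holding one of these means validation has passed.
pub struct RandomForestValidParams {
    pub n_trees: usize,
    pub feature_subsample: f64,
    pub max_depth: Option<usize>,
}

/// Unchecked builder; `check()` plays the role of `ParamGuard::check`.
pub struct RandomForestParams(RandomForestValidParams);

impl RandomForestParams {
    pub fn new() -> Self {
        // Defaults here are assumptions, not settled API.
        Self(RandomForestValidParams {
            n_trees: 100,
            feature_subsample: 1.0,
            max_depth: None,
        })
    }

    pub fn n_trees(mut self, n: usize) -> Self {
        self.0.n_trees = n;
        self
    }

    pub fn feature_subsample(mut self, ratio: f64) -> Self {
        self.0.feature_subsample = ratio;
        self
    }

    pub fn max_depth(mut self, depth: Option<usize>) -> Self {
        self.0.max_depth = depth;
        self
    }

    /// Reject invalid settings before any training happens.
    pub fn check(self) -> Result<RandomForestValidParams, String> {
        if self.0.n_trees == 0 {
            return Err("n_trees must be at least 1".into());
        }
        if self.0.feature_subsample <= 0.0 || self.0.feature_subsample > 1.0 {
            return Err("feature_subsample must lie in (0, 1]".into());
        }
        Ok(self.0)
    }
}

fn main() {
    let valid = RandomForestParams::new()
        .n_trees(10)
        .feature_subsample(0.5)
        .check()
        .expect("parameters should validate");
    assert_eq!(valid.n_trees, 10);
    assert!(RandomForestParams::new().n_trees(0).check().is_err());
}
```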
🔹 Trait Implementations
I will implement the following traits according to linfa conventions:
- `ParamGuard` for parameter validation
- `Fit` to train the forest using bootstrapped data and random feature subsetting
- `PredictInplace` and `Predict` to perform inference via majority voting
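To make the training and inference behaviour concrete, here is a dependency-free sketch of two core ingredients, bootstrap sampling (used by `Fit`) and majority voting (used by `Predict`). The tiny LCG stands in for the `rand` crate so the snippet stays self-contained; everything here is illustrative, not the proposed implementation:

```rust
/// Minimal linear congruential generator, used only so this sketch needs
/// no external crates; a real implementation would use the `rand` crate.
struct Lcg(u64);

impl Lcg {
    /// Next pseudo-random index in `0..bound`.
    fn next_index(&mut self, bound: usize) -> usize {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 33) as usize) % bound
    }
}

/// Draw `n` row indices with replacement: one bootstrap sample per tree.
fn bootstrap_indices(rng: &mut Lcg, n: usize) -> Vec<usize> {
    (0..n).map(|_| rng.next_index(n)).collect()
}

/// Combine per-tree class predictions by majority vote.
/// `tree_preds[t][i]` is tree `t`'s predicted class for sample `i`.
fn majority_vote(tree_preds: &[Vec<usize>], n_classes: usize) -> Vec<usize> {
    let n_samples = tree_preds[0].len();
    (0..n_samples)
        .map(|i| {
            let mut counts = vec![0usize; n_classes];
            for preds in tree_preds {
                counts[preds[i]] += 1;
            }
            // Argmax over vote counts; ties go to the lower class index.
            let mut best = 0;
            for c in 1..n_classes {
                if counts[c] > counts[best] {
                    best = c;
                }
            }
            best
        })
        .collect()
}

fn main() {
    let mut rng = Lcg(42);
    let sample = bootstrap_indices(&mut rng, 5);
    assert_eq!(sample.len(), 5);
    assert!(sample.iter().all(|&i| i < 5));

    // Three "trees" vote on two samples: class 0 wins the first,
    // class 1 wins the second.
    let preds = vec![vec![0, 1], vec![0, 1], vec![1, 1]];
    assert_eq!(majority_vote(&preds, 2), vec![0, 1]);
}
```

In the actual module these helpers would operate on `Dataset` views and reuse the existing `DecisionTree` for the per-tree fits.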
🔹 Example
An example will be added in:
`linfa-trees/examples/iris_random_forest.rs`
Using the Iris dataset from linfa-datasets.
🔹 Benchmark (Optional)
If approved, I can also add a benchmark using Criterion:
`linfa-trees/benches/random_forest.rs`
📁 File Integration Plan
- `src/lib.rs`: Re-export `random_forest::*`
- `src/decision_trees/mod.rs`: `pub mod random_forest;`
- `README.md`: Update with a section on Random Forests and example usage
- `examples/iris_random_forest.rs`: Demonstrates training and evaluation
📦 API Preview
```rust
let model = RandomForest::params()
    .n_trees(100)
    .feature_subsample(0.8)
    .max_depth(Some(10))
    .fit(&dataset)?;

let predictions = model.predict(&dataset);
let acc = predictions.confusion_matrix(&dataset)?.accuracy();
```
✅ Conformity with CONTRIBUTING.md
- Uses the `Float` trait for `f32`/`f64` compatibility
- Follows the `Params` → `ValidParams` validation pattern
- Implements `Fit`, `Predict`, and `PredictInplace` using `Dataset`
- Optional `serde` support via feature flag
- Will include unit tests and optionally benchmarks
🙋‍♂️ Request
Please let me know if you're open to this contribution. I’d be happy to align with maintainers on:
- Feature scope (classifier first, regressor later?)
- Benchmarking standards
- Integration strategy (e.g., reuse of `DecisionTree`)
Looking forward to your guidance!