Skip to content

Commit

Permalink
chore: add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
stillmatic committed Aug 21, 2024
1 parent da9b558 commit 408be30
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 2 deletions.
27 changes: 27 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package arboreal_test

import (
"github.com/stillmatic/arboreal"
"github.com/stretchr/testify/assert"
"testing"
)

func TestXGBoostSchema_Predict(t *testing.T) {
schema, err := arboreal.NewGBDTFromXGBoostJSON("testdata/regression.json")
assert.NoError(t, err)
inpArr := []float32{0.1, 0.2, 0.3, 0.4, 0.5}
inpVec := arboreal.SparseVectorFromArray(inpArr)
res, err := schema.Predict(inpVec)
assert.NoError(t, err)
assert.InDelta(t, 6.245257, res[0], 0.000001)
}

func TestOptimizedGBTModel(t *testing.T) {
res, err := arboreal.NewGBDTFromXGBoostJSON("testdata/mortgage_xgb.json")
assert.NoError(t, err)
newRes := arboreal.NewOptimizedGBDTClassifierFromSchema(res)

vec := make(arboreal.SparseVector, 44)
_, err = newRes.Predict(vec)
assert.NoError(t, err)
}
4 changes: 4 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Pure Go library for gradient boosted decision trees. The library is optimized fo

# Usage

```bash
go get github.com/stillmatic/arboreal
```

```go
res, err := arboreal.NewGBDTFromXGBoostJSON("testdata/regression.json")
inpArr := []float32{0.1, 0.2, 0.3, 0.4, 0.5}
Expand Down
1 change: 1 addition & 0 deletions xgboost.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Package arboreal is a pure Go package for XGBoost model inference.
package arboreal

import (
Expand Down
5 changes: 3 additions & 2 deletions xgboostio.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package arboreal

// IO for XGBoost JSON files
// see https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
// https://github.com/dmlc/xgboost/blob/24c237308097b693b744af2ad1f86f44be068523/demo/json-model/json_parser.py
package arboreal

import (
"encoding/json"
Expand All @@ -12,7 +13,7 @@ import (
"github.com/pkg/errors"
)

// custom JSON unmarshal for learner
// UnmarshalJSON is a custom JSON unmarshal for learner
func (l *learner) UnmarshalJSON(b []byte) error {
var tmp struct {
FeatureNames []featureName `json:"feature_names,omitempty"`
Expand Down

0 comments on commit 408be30

Please sign in to comment.