Skip to content

Commit

Permalink
feat: clean up pointer interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
stillmatic committed Aug 21, 2024
1 parent 33d0199 commit da9b558
Show file tree
Hide file tree
Showing 10 changed files with 662 additions and 56 deletions.
8 changes: 8 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@ require (
)

require (
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect
github.com/alecthomas/units v0.0.0-20240626203959-61d1e3462e30 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gedex/inflector v0.0.0-20170307190818-16278e9db813 // indirect
github.com/idubinskiy/schematyper v0.0.0-20190118213059-f71b40dac30d // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/viterin/partial v1.1.0 // indirect
github.com/viterin/vek v0.4.2 // indirect
golang.org/x/sys v0.24.0 // indirect
gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
25 changes: 25 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,16 +1,41 @@
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20240626203959-61d1e3462e30 h1:t3eaIm0rUkzbrIewtiFmMK5RXHej2XnoXNhxVsAYUfg=
github.com/alecthomas/units v0.0.0-20240626203959-61d1e3462e30/go.mod h1:fvzegU4vN3H1qMT+8wDmzjAcDONcgo2/SZ/TyfdUOFs=
github.com/chewxy/math32 v1.11.0 h1:8sek2JWqeaKkVnHa7bPVqCEOUPbARo4SGxs6toKyAOo=
github.com/chewxy/math32 v1.11.0/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gedex/inflector v0.0.0-20170307190818-16278e9db813 h1:Uc+IZ7gYqAf/rSGFplbWBSHaGolEQlNLgMgSE3ccnIQ=
github.com/gedex/inflector v0.0.0-20170307190818-16278e9db813/go.mod h1:P+oSoE9yhSRvsmYyZsshflcR6ePWYLql6UU1amW13IM=
github.com/idubinskiy/schematyper v0.0.0-20190118213059-f71b40dac30d h1:sQbbvtUoen3Tfl9G/079tXeqniwPH6TgM/lU4y7lQN8=
github.com/idubinskiy/schematyper v0.0.0-20190118213059-f71b40dac30d/go.mod h1:xVHEhsiSJJnT0jlcQpQUg+GyoLf0i0xciM1kqWTGT58=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/viterin/partial v1.1.0 h1:iH1l1xqBlapXsYzADS1dcbizg3iQUKTU1rbwkHv/80E=
github.com/viterin/partial v1.1.0/go.mod h1:oKGAo7/wylWkJTLrWX8n+f4aDPtQMQ6VG4dd2qur5QA=
github.com/viterin/vek v0.4.2 h1:Vyv04UjQT6gcjEFX82AS9ocgNbAJqsHviheIBdPlv5U=
github.com/viterin/vek v0.4.2/go.mod h1:A4JRAe8OvbhdzBL5ofzjBS0J29FyUrf95tQogvtHHUc=
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI=
golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
38 changes: 38 additions & 0 deletions math_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package arboreal_test

import (
"github.com/viterin/vek/vek32"
"testing"

math "github.com/chewxy/math32"
Expand Down Expand Up @@ -28,6 +29,27 @@ func softmax(ys []float32) []float32 {
return output
}

// softmaxSimd is a vek32-based softmax variant used for benchmarking.
// It allocates a result slice the same length as vector, fills it via
// vek32.Exp_Into, and then scales it by the reciprocal of its vek32.Sum.
// The scaling step is skipped when that sum is exactly zero, in which case
// the unscaled result is returned as-is.
func softmaxSimd(vector []float32) []float32 {
	out := make([]float32, len(vector))
	vek32.Exp_Into(out, vector)

	total := vek32.Sum(out)
	if total == 0 {
		// Nothing to normalise by; avoid dividing by zero.
		return out
	}

	vek32.MulNumber_Inplace(out, 1/total)
	return out
}

// softmaxSimdInplace is the allocation-free counterpart of the vek32 softmax:
// it runs vek32.Exp_Inplace directly on vector, then multiplies vector by the
// reciprocal of its vek32.Sum unless that sum is exactly zero.
//
// The caller's slice is overwritten, and that same slice is returned.
func softmaxSimdInplace(vector []float32) []float32 {
	vek32.Exp_Inplace(vector)
	if total := vek32.Sum(vector); total != 0 {
		vek32.MulNumber_Inplace(vector, 1/total)
	}
	return vector
}

func softmaxAlt(vector []float32) []float32 {
sum := float32(0.0)
r := make([]float32, len(vector))
Expand Down Expand Up @@ -76,8 +98,14 @@ func (s *sigmoidTable) sigmoid(x float32) float32 {
return s.expTable[int((x+s.maxExp)*s.cache)]
}

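// In-place SIMD saves an allocation and a few nanoseconds, but the difference is small.
// Note: the SIMD_Inplace case overwrites the shared input vector on every iteration,
// so it is not measured on the same data as the other cases.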
// BenchmarkSoftmax/softmax-10 10690101 107.1 ns/op 48 B/op 1 allocs/op
// BenchmarkSoftmax/softmaxAlt-10 11185256 107.5 ns/op 48 B/op 1 allocs/op
// BenchmarkSoftmax/SIMD-10 10492280 113.2 ns/op 48 B/op 1 allocs/op
// BenchmarkSoftmax/SIMD_Inplace-10 11581250 103.2 ns/op 0 B/op 0 allocs/op
func BenchmarkSoftmax(b *testing.B) {
vector := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
b.ResetTimer()
b.Run("softmax", func(b *testing.B) {
for i := 0; i < b.N; i++ {
arboreal.Softmax(vector)
Expand All @@ -88,6 +116,16 @@ func BenchmarkSoftmax(b *testing.B) {
softmaxAlt(vector)
}
})
b.Run("SIMD", func(b *testing.B) {
for i := 0; i < b.N; i++ {
softmaxSimd(vector)
}
})
b.Run("SIMD_Inplace", func(b *testing.B) {
for i := 0; i < b.N; i++ {
softmaxSimdInplace(vector)
}
})
}

func BenchmarkSigmoid(b *testing.B) {
Expand Down
13 changes: 6 additions & 7 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@ package arboreal

import (
"encoding/json"
"io/ioutil"

"github.com/pkg/errors"
"os"
)

func NewGBDTFromXGBoostJSON(filename string) (xgboostSchema, error) {
var schema xgboostSchema
jsonIO, err := ioutil.ReadFile(filename)
func NewGBDTFromXGBoostJSON(filename string) (*XGBoostSchema, error) {
var schema *XGBoostSchema
jsonIO, err := os.ReadFile(filename)
if err != nil {
return schema, errors.Wrapf(err, "failed to open %s", filename)
return nil, errors.Wrapf(err, "failed to open %s", filename)
}
err = json.Unmarshal(jsonIO, &schema)
if err != nil {
return schema, errors.Wrapf(err, "couldn't unmarshal json")
return nil, errors.Wrapf(err, "couldn't unmarshal json")
}
return schema, nil
}
10 changes: 5 additions & 5 deletions xgboost.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

type GradientBooster interface {
GetName() string
Predict(features *SparseVector) ([]float32, error)
Predict(features SparseVector) ([]float32, error)
}

type GBLinear struct {
Expand All @@ -18,7 +18,7 @@ type GBLinear struct {
} `json:"model"`
}

func (m *GBLinear) Predict(features *SparseVector) ([]float32, error) {
func (m *GBLinear) Predict(features SparseVector) ([]float32, error) {
var result []float32
return result, errors.New("not yet implemented")
}
Expand All @@ -27,7 +27,7 @@ func (m *GBLinear) GetName() string {
return m.Name
}

func (m *GBTree) Predict(features *SparseVector) ([]float32, error) {
func (m *GBTree) Predict(features SparseVector) ([]float32, error) {
result := make([]float32, len(m.Model.Trees))

for idx, tree := range m.Model.Trees {
Expand All @@ -44,11 +44,11 @@ func (m *GBTree) GetName() string {
return m.Name
}

func (t *tree) Predict(features *SparseVector) (float32, error) {
func (t *tree) Predict(features SparseVector) (float32, error) {
return 0.0, nil
}

func (m *xgboostSchema) Predict(features *SparseVector) ([]float32, error) {
func (m *XGBoostSchema) Predict(features SparseVector) ([]float32, error) {
internalResults, err := m.Learner.GradientBooster.Predict(features)
if err != nil {
return nil, errors.Wrap(err, "failed to predict with gradient booster")
Expand Down
25 changes: 8 additions & 17 deletions xgboost_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@ type aftLossParam struct {
AftLossDistributionScale string `json:"aft_loss_distribution_scale,omitempty"`
}

type categoricalSize int

type categoriesNode int

type categoriesSegment int

type category int

type featureName string

type featureType string
Expand All @@ -26,9 +18,8 @@ type GBTree struct {
}

type gbtreeModelParam struct {
NumTrees string `json:"num_trees"`
// NumParallelTree string `json:"num_parallel_tree,omitempty"`
// SizeLeafVector string `json:"size_leaf_vector"`
NumTrees string `json:"num_trees"`
NumParallelTree string `json:"num_parallel_tree"`
}

type lambdaRankParam struct {
Expand All @@ -48,6 +39,7 @@ type learnerModelParam struct {
BaseScore float32 `json:"base_score,omitempty"`
NumClass int `json:"num_class,omitempty"`
NumFeature int `json:"num_feature,omitempty"`
NumTarget int `json:"num_target,omitempty"`
}

type model struct {
Expand All @@ -70,7 +62,7 @@ type softmaxMulticlassParam struct {

type tree struct {
// BaseWeights []float32 `json:"base_weights"`
CategoricalSizes []int `json:"categorical_sizes,omitempty"`
CategoriesSizes []int `json:"categories_sizes,omitempty"`
Categories []int `json:"categories"`
CategoriesNodes []int `json:"categories_nodes"`
CategoriesSegments []int `json:"categories_segments"`
Expand All @@ -97,13 +89,12 @@ type treeTreeParam struct {
SizeLeafVector string `json:"size_leaf_vector"`
}

type xgboostSchema struct {
type XGBoostSchema struct {
Learner *learner `json:"learner"`
Version []int `json:"version"`
}

type xgboostSchemaTreeParam struct {
NumFeature string `json:"num_feature"`
NumNodes string `json:"num_nodes"`
SizeLeafVector string `json:"size_leaf_vector"`
type XGBoostSchemaTreeParam struct {
NumFeature string `json:"num_feature"`
NumNodes string `json:"num_nodes"`
}
20 changes: 10 additions & 10 deletions xgboost_optimized.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ type OptimizedGBDTClassifier struct {
NumClasses int
}

func NewOptimizedGBDTClassifierFromSchema(model *xgboostSchema) OptimizedGBDTClassifier {
func NewOptimizedGBDTClassifierFromSchema(model *XGBoostSchema) OptimizedGBDTClassifier {
origModel := model.Learner.GradientBooster.(*GBTModelOptimized)
return OptimizedGBDTClassifier{
Model: origModel,
Expand All @@ -27,12 +27,12 @@ func sigmoidSingleOpt(x float32) float32 {
return 1.0 / (1.0 + math.Exp(-x))
}

func (m *OptimizedGBDTClassifier) Predict(features *SparseVector) ([]float32, error) {
func (m *OptimizedGBDTClassifier) Predict(features SparseVector) ([]float32, error) {
numClasses := max(m.NumClasses, 1)
treesPerClass := len(m.Model.Trees) / numClasses
perClassScore := make([]float32, numClasses)
for i := 0; i < numClasses; i++ {
offset := (i * treesPerClass)
offset := i * treesPerClass
for j := 0; j < treesPerClass; j++ {
perClassScore[i] += m.Model.Trees[offset+j].Predict(features)
}
Expand All @@ -54,7 +54,7 @@ func (m *OptimizedGBDTClassifier) PredictFloats(features []float32) ([]float32,
for i := 0; i < numClasses; i++ {
offset := (i * treesPerClass)
for j := 0; j < treesPerClass; j++ {
perClassScore[i] += m.Model.Trees[offset+j].Predict(&sv)
perClassScore[i] += m.Model.Trees[offset+j].Predict(sv)
}
perClassScore[i] = sigmoidSingleOpt(perClassScore[i])
}
Expand Down Expand Up @@ -88,7 +88,7 @@ func (m *GBTModelOptimized) GetName() string {
return "gbtree_optimized"
}

func (m *GBTModelOptimized) Predict(features *SparseVector) ([]float32, error) {
func (m *GBTModelOptimized) Predict(features SparseVector) ([]float32, error) {
result := make([]float32, len(m.Trees))

for idx, tree := range m.Trees {
Expand All @@ -98,7 +98,7 @@ func (m *GBTModelOptimized) Predict(features *SparseVector) ([]float32, error) {
return result, nil
}

func (t *TreeOptimized) predictCategorical(features *SparseVector) float32 {
func (t *TreeOptimized) predictCategorical(features SparseVector) float32 {
idx := 0

for {
Expand All @@ -116,7 +116,7 @@ func (t *TreeOptimized) predictCategorical(features *SparseVector) float32 {

splitCol := node.SplitIndex
// splitVal := node.SplitCondition
fval, ok := (*features)[splitCol]
fval, ok := features[splitCol]

// missing value behavior is determined by default left
if !ok {
Expand All @@ -136,7 +136,7 @@ func (t *TreeOptimized) predictCategorical(features *SparseVector) float32 {
}
}

func (t *TreeOptimized) predictNumerical(features *SparseVector) float32 {
func (t *TreeOptimized) predictNumerical(features SparseVector) float32 {
idx := 0
for {
node := t.Nodes[idx]
Expand All @@ -153,7 +153,7 @@ func (t *TreeOptimized) predictNumerical(features *SparseVector) float32 {

splitCol := node.SplitIndex
splitVal := node.SplitCondition
fval, ok := (*features)[splitCol]
fval, ok := features[splitCol]

// missing value behavior is determined by default left
if !ok {
Expand All @@ -173,7 +173,7 @@ func (t *TreeOptimized) predictNumerical(features *SparseVector) float32 {
}
}

func (t *TreeOptimized) Predict(features *SparseVector) float32 {
func (t *TreeOptimized) Predict(features SparseVector) float32 {
if t.Nodes[0].SplitType == 1 {
return t.predictCategorical(features)
} else {
Expand Down
Loading

0 comments on commit da9b558

Please sign in to comment.