Skip to content

Commit

Permalink
chore: dont export some utils
Browse files Browse the repository at this point in the history
  • Loading branch information
stillmatic committed Aug 21, 2024
1 parent 408be30 commit 50dd42f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
13 changes: 2 additions & 11 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
package arboreal

import (
"encoding/json"
)

// MustNotError is a generic function to get the output of a function that returns
// mustNotError is a generic function to get the output of a function that returns
// a value and an error. If the error is not nil, it will panic.
func MustNotError[T any](input T, err error) T {
func mustNotError[T any](input T, err error) T {
if err != nil {
panic(err)
}
return input
}

func PrettyPrint(i interface{}) string {
s, _ := json.MarshalIndent(i, "", "\t")
return string(s)
}
23 changes: 16 additions & 7 deletions xgboost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ import (
"github.com/stretchr/testify/assert"
)

// mustNotError is a generic function to get the output of a function that returns
// a value and an error. If the error is not nil, it will panic.
func mustNotError[T any](input T, err error) T {
if err != nil {
panic(err)
}
return input
}

var vec = arboreal.SparseVector{
0: 2016.0,
1: 1.0,
Expand Down Expand Up @@ -66,7 +75,7 @@ func TestXGBoostJson(t *testing.T) {
// assert.Equal(t, res.Learner.GradientBooster.GetName(), "gbtree")
// gb := res.Learner.GradientBooster.(*arboreal.GBTree)
// assert.Equal(t, len(gb.Model.Trees), 100)
// assert.Equal(t, len(gb.Model.Trees), arboreal.MustNotError(strconv.Atoi(gb.Model.GbtreeModelParam.NumTrees)))
// assert.Equal(t, len(gb.Model.Trees), mustNotError(strconv.Atoi(gb.Model.GbtreeModelParam.NumTrees)))

// t0 := gb.Model.Trees[0]
// treeRes, err := t0.Predict(vec, gb.BaseScore)
Expand Down Expand Up @@ -110,7 +119,7 @@ func TestToy(t *testing.T) {
12: 40,
13: 38,
}
res0 := arboreal.MustNotError(res.Predict(sv0))
res0 := mustNotError(res.Predict(sv0))
t.Log((res0))
assert.InDelta(t, 0.4343974019963509, res0[0], 0.01)
sv1 := arboreal.SparseVector{
Expand All @@ -129,7 +138,7 @@ func TestToy(t *testing.T) {
12: 50,
13: 38,
}
res1 := arboreal.MustNotError(res.Predict(sv1))
res1 := mustNotError(res.Predict(sv1))
t.Log((res1))
assert.InDelta(t, 0.4694540577007751, res1[0], 0.01)
}
Expand Down Expand Up @@ -309,7 +318,7 @@ func BenchmarkXGBEndToEnd(b *testing.B) {
for i, input := range inputs {
floatInputs[i] = make([]float32, len(input))
for j, v := range input {
floatInputs[i][j] = float32(arboreal.MustNotError(strconv.ParseFloat(v, 32)))
floatInputs[i][j] = float32(mustNotError(strconv.ParseFloat(v, 32)))
}
}
for i := 0; i < b.N; i++ {
Expand All @@ -328,7 +337,7 @@ func BenchmarkXGBEndToEndConcurrent(b *testing.B) {
for i, input := range inputs {
floatInputs[i] = make([]float32, len(input))
for j, v := range input {
floatInputs[i][j] = float32(arboreal.MustNotError(strconv.ParseFloat(v, 32)))
floatInputs[i][j] = float32(mustNotError(strconv.ParseFloat(v, 32)))
}
}
for i := 0; i < b.N; i++ {
Expand All @@ -350,7 +359,7 @@ func BenchmarkXGBEndToEndOptimized(b *testing.B) {
for i, input := range inputs {
floatInputs[i] = make([]float32, len(input))
for j, v := range input {
floatInputs[i][j] = float32(arboreal.MustNotError(strconv.ParseFloat(v, 32)))
floatInputs[i][j] = float32(mustNotError(strconv.ParseFloat(v, 32)))
}
}
for i := 0; i < b.N; i++ {
Expand All @@ -370,7 +379,7 @@ func BenchmarkXGBEndToEndOptimizedConcurrent(b *testing.B) {
for i, input := range inputs {
floatInputs[i] = make([]float32, len(input))
for j, v := range input {
floatInputs[i][j] = float32(arboreal.MustNotError(strconv.ParseFloat(v, 32)))
floatInputs[i][j] = float32(mustNotError(strconv.ParseFloat(v, 32)))
}
}
for i := 0; i < b.N; i++ {
Expand Down
6 changes: 3 additions & 3 deletions xgboostio.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ func (l *learnerModelParam) UnmarshalJSON(b []byte) error {
if err := json.Unmarshal(b, &tmp); err != nil {
return err
}
l.BaseScore = float32(MustNotError(strconv.ParseFloat(tmp.BaseScore, 64)))
l.NumClass = MustNotError(strconv.Atoi(tmp.NumClass))
l.NumFeature = MustNotError(strconv.Atoi(tmp.NumFeature))
l.BaseScore = float32(mustNotError(strconv.ParseFloat(tmp.BaseScore, 64)))
l.NumClass = mustNotError(strconv.Atoi(tmp.NumClass))
l.NumFeature = mustNotError(strconv.Atoi(tmp.NumFeature))
return nil
}

Expand Down

0 comments on commit 50dd42f

Please sign in to comment.