Skip to content

Commit

Permalink
feat: add softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
stillmatic committed Jul 31, 2022
1 parent 5b86c6c commit 726d001
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 15 deletions.
7 changes: 0 additions & 7 deletions math.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ func max[T constraints.Ordered](a, b T) T {
return b
}

// sigmoid applies the logistic function to every element of x in place
// and returns the same slice for convenience.
func sigmoid(x []float64) []float64 {
	for i := range x {
		x[i] = sigmoidSingle(x[i])
	}
	return x
}

// sigmoidSingle maps a single real value into the open interval (0, 1)
// via the standard logistic function 1 / (1 + e^-x).
func sigmoidSingle(x float64) float64 {
	return 1 / (1 + math.Exp(-x))
}
Expand Down
24 changes: 18 additions & 6 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,32 @@ func (m *xgboostSchema) Predict(features *SparseVector) ([]float64, error) {
perClassScore := make([]float64, numClasses)
for i := 0; i < numClasses; i++ {
for j := 0; j < treesPerClass; j++ {
perClassScore[i] += internalResults[(i*treesPerClass)+j]
var idx int
// there has GOT to be a better way to do this
switch m.Learner.Objective.Name {
case "multi:softprob", "multi:softmax":
idx = i % numClasses
default:
idx = i*treesPerClass + j
}
perClassScore[i] += internalResults[idx]
}
switch m.Learner.Objective.Name {
case "reg:squarederror":
case "reg:squarederror", "reg:squaredlogerror", "reg:pseudohubererror":
// weirdly only applied to regression, not to binary classification
perClassScore[i] += m.Learner.LearnerModelParam.BaseScore
case "reg:logistic", "binary:logistic":
perClassScore[i] = sigmoidSingle(perClassScore[i])
}
}
// TODO: handle objective
// final post process
switch m.Learner.Objective.Name {
case "reg:logistic", "binary:logistic":
return sigmoid(perClassScore), nil
case "multi:softmax":
return perClassScore, nil
case "multi:softmax", "multi:softprob":
fmt.Println("softmax", perClassScore)
return Softmax(perClassScore), nil
case "reg:squarederror":
case "reg:squarederror", "reg:squaredlogerror", "reg:pseudohubererror":
return perClassScore, nil
default:
return nil, fmt.Errorf("unknown objective: %s", m.Learner.Objective)
Expand Down
1 change: 1 addition & 0 deletions testdata/toysoftmax.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"learner":{"attributes":{"best_iteration":"1","best_ntree_limit":"2"},"feature_names":["age","workclass","fnlwgt","education","education_num","occupation","relationship","race","sex","capital_gain","capital_loss","hours_per_week","native_country","wage_class"],"feature_types":["int","int","int","int","int","int","int","int","int","int","int","int","int","int"],"gradient_booster":{"model":{"gbtree_model_param":{"num_parallel_tree":"1","num_trees":"4","size_leaf_vector":"0"},"tree_info":[0,1,0,1],"trees":[{"base_weights":[6.4899564E-2,-9.98001E-1,8.168576E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[1,0,0],"id":0,"left_children":[1,-1,-1],"loss_changes":[9.649841E3,0E0,0E0],"parents":[2147483647,0,0],"right_children":[2,-1,-1],"split_conditions":[5E-1,-9.98001E-2,8.168576E-2],"split_indices":[6,0,0],"split_type":[0,0,0],"sum_hessian":[1.20715E4,5.0015E3,7.07E3],"tree_param":{"num_deleted":"0","num_feature":"14","num_nodes":"3","size_leaf_vector":"0"}},{"base_weights":[-6.673816E-2,9.9820286E-1,-8.163036E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[1,0,0],"id":1,"left_children":[1,-1,-1],"loss_changes":[9.677116E3,0E0,0E0],"parents":[2147483647,0,0],"right_children":[2,-1,-1],"split_conditions":[5E-1,9.9820286E-2,-8.1630364E-2],"split_indices":[6,0,0],"split_type":[0,0,0],"sum_hessian":[1.2121E4,5.007E3,7.114E3],"tree_param":{"num_deleted":"0","num_feature":"14","num_nodes":"3","size_leaf_vector":"0"}},{"base_weights":[5.8483366E-2,-9.075136E-1,7.3992443E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[1,0,0],"id":2,"left_children":[1,-1,-1],"loss_changes":[7.8570566E3,0E0,0E0],"parents":[2147483647,0,0],"right_children":[2,-1,-1],"split_conditions":[5E-1,-9.0751365E-2,7.3992446E-2],"split_indices":[6,0,0],"split_type":[0,0,0],"sum_hessian":[1.1933897E4,4.9361626E3,6.997735E3],"tree_param":{"num
_deleted":"0","num_feature":"14","num_nodes":"3","size_leaf_vector":"0"}},{"base_weights":[-6.1782066E-2,9.0811884E-1,-7.3983353E-1],"categories":[],"categories_nodes":[],"categories_segments":[],"categories_sizes":[],"default_left":[1,0,0],"id":3,"left_children":[1,-1,-1],"loss_changes":[7.8770264E3,0E0,0E0],"parents":[2147483647,0,0],"right_children":[2,-1,-1],"split_conditions":[5E-1,9.0811886E-2,-7.398336E-2],"split_indices":[6,0,0],"split_type":[0,0,0],"sum_hessian":[1.1975647E4,4.927251E3,7.0483965E3],"tree_param":{"num_deleted":"0","num_feature":"14","num_nodes":"3","size_leaf_vector":"0"}}]},"name":"gbtree"},"learner_model_param":{"base_score":"5E-1","num_class":"2","num_feature":"14","num_target":"1"},"objective":{"name":"multi:softmax","softmax_multiclass_param":{"num_class":"2"}}},"version":[1,6,1]}
6 changes: 4 additions & 2 deletions xgboostio.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ func (l *learner) UnmarshalJSON(b []byte) error {
l.LearnerModelParam = tmp.LearnerModelParam
var err error
l.GradientBooster, err = parseGradientBooster(tmp.GradientBooster)
l.Objective, err = parseObjective(tmp.Objective)

if err != nil {
return errors.Wrapf(err, "failed to parse gradient booster")
}
l.Objective, err = parseObjective(tmp.Objective)
if err != nil {
return errors.Wrapf(err, "failed to parse objective")
}
return nil
}

Expand Down
47 changes: 47 additions & 0 deletions xgboostio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func TestToy(t *testing.T) {
13: 38,
}
res0 := arboreal.MustNotError(res.Predict(sv0))
t.Log((res0))
assert.InDelta(t, 0.4343974019963509, res0[0], 0.01)
sv1 := &arboreal.SparseVector{
0: 38,
Expand All @@ -123,6 +124,7 @@ func TestToy(t *testing.T) {
13: 38,
}
res1 := arboreal.MustNotError(res.Predict(sv1))
t.Log((res1))
assert.InDelta(t, 0.4694540577007751, res1[0], 0.01)
}

Expand All @@ -134,6 +136,51 @@ func TestRegression(t *testing.T) {
assert.InDelta(t, 8.417279, score[0], 0.01)
}

// TestSoftprob checks multiclass softmax prediction against two known
// feature vectors from the toy adult-income model.
func TestSoftprob(t *testing.T) {
	model, err := arboreal.NewGBDTFromXGBoostJSON("testdata/toysoftmax.json")
	assert.NoError(t, err)

	cases := []struct {
		features arboreal.SparseVector
		want     float64 // expected probability for class 0
	}{
		{
			features: arboreal.SparseVector{
				0: 25, 1: 2, 2: 226802, 3: 1, 4: 7, 5: 6, 6: 3,
				7: 2, 8: 1, 9: 0, 10: 0, 11: 40, 12: 38, 13: 0,
			},
			want: 0.57720053,
		},
		{
			features: arboreal.SparseVector{
				0: 38, 1: 2, 2: 89814, 3: 11, 4: 9, 5: 4, 6: 0,
				7: 4, 8: 1, 9: 0, 10: 0, 11: 50, 12: 38, 13: 0,
			},
			want: 0.40584144,
		},
	}

	for _, tc := range cases {
		score, err := model.Predict(&tc.features)
		assert.NoError(t, err)
		assert.InDelta(t, tc.want, score[0], 0.01)
		t.Log(score)
	}
}

func BenchmarkXGBoost(b *testing.B) {
res, err := arboreal.NewGBDTFromXGBoostJSON("testdata/mortgage_xgb.json")
assert.NoError(b, err)
Expand Down

0 comments on commit 726d001

Please sign in to comment.