softmax decoder
Randall C. O'Reilly committed Jun 2, 2021
1 parent 02ab07f commit 6c1cf27
Showing 3 changed files with 105 additions and 18 deletions.
11 changes: 11 additions & 0 deletions decoder/README.md
@@ -0,0 +1,11 @@
# Decoder

The decoder package provides standalone decoders that sample variable values from `emer` network layers and decode them into categorical outputs.

# SoftMax

The `SoftMax` decoder is the best choice for 1-hot classification.

Call `InitLayer` to initialize with the number of categories and the layers to decode from, or `Init` with the number of categories and the total number of input values.

Call `Decode` with the name of the layer variable to decode -- it runs the forward pass and returns the index of the most likely category (`Sorted[0]`). Call `Train` with the index of the correct category to update the decoder weights.
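
As a quick orientation (not part of the committed files), here is a minimal usage sketch that drives the decoder's `Inputs` directly rather than through `emer` layers, adapted from the test added in this commit; the import path `github.com/emer/emergent/decoder` and the learning-rate value are assumptions.

```go
package main

import (
    "fmt"

    "github.com/emer/emergent/decoder" // assumed import path for this package
)

func main() {
    dec := decoder.SoftMax{}
    dec.Init(2, 2) // 2 categories, 2 raw inputs; use InitLayer to decode from emer layers instead
    dec.Lrate = 0.1
    for i := 0; i < 100; i++ {
        trg := i % 2 // alternate which category is correct
        dec.Inputs[0] = float32(1 - trg)
        dec.Inputs[1] = float32(trg)
        dec.Forward()                      // softmax over the current Inputs
        dec.Sort()                         // fills Sorted, strongest category first
        fmt.Println(i, trg, dec.Sorted[0]) // Sorted[0] is the decoded category
        dec.Train(trg)                     // update weights toward the correct category
    }
}
```

When decoding from an actual network, `InitLayer` plus `Decode` with a layer variable name replace the manual `Inputs`/`Forward`/`Sort` steps shown here.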
73 changes: 55 additions & 18 deletions decoder/decoder.go → decoder/softmax.go
@@ -6,6 +6,7 @@ package decoder
 
 import (
     "math"
+    "sort"
 
     "github.com/emer/emergent/emer"
     "github.com/emer/etable/etensor"
@@ -18,6 +19,7 @@ type SoftMax struct {
     Layers  []emer.Layer `desc:"layers to decode"`
     NCats   int          `desc:"number of different categories to decode"`
     Units   []Unit       `desc:"unit values"`
+    Sorted  []int        `desc:"sorted list of indexes into Units, in descending order from strongest to weakest -- i.e., Sorted[0] has the most likely categorization, and its activity is Units[Sorted[0]].Act"`
     NInputs int          `desc:"number of inputs -- total sizes of layer inputs"`
     Inputs  []float32    `desc:"input values, copied from layers"`
     Targ    int          `desc:"current target index of correct category"`
@@ -27,28 +29,43 @@ type SoftMax struct {
 
 // Unit has variables for decoder unit
 type Unit struct {
-    Act      float32 `desc:"final activation = e^Ge / sum e^Ge"`
-    Net      float32 `desc:"net input = sum x * w"`
-    Exp      float32 `desc:"exp(Net)"`
-    DActDNet float32 `desc:"derivative of activation with respect to net input"`
+    Act float32 `desc:"final activation = e^Ge / sum e^Ge"`
+    Net float32 `desc:"net input = sum x * w"`
+    Exp float32 `desc:"exp(Net)"`
 }
 
-// Init initializes detector with number of categories and layers
-func (sm *SoftMax) Init(ncats int, layers []emer.Layer) {
-    sm.NCats = ncats
-    sm.Units = make([]Unit, ncats)
+// InitLayer initializes detector with number of categories and layers
+func (sm *SoftMax) InitLayer(ncats int, layers []emer.Layer) {
     sm.Layers = layers
-    sm.NInputs = 0
+    nin := 0
     for _, ly := range sm.Layers {
-        sm.NInputs += ly.Shape().Len()
+        nin += ly.Shape().Len()
     }
+    sm.Init(ncats, nin)
 }
 
+// Init initializes detector with number of categories and number of inputs
+func (sm *SoftMax) Init(ncats, ninputs int) {
+    sm.NInputs = ninputs
+    sm.Lrate = 0.01
+    sm.NCats = ncats
+    sm.Units = make([]Unit, ncats)
+    sm.Sorted = make([]int, ncats)
+    sm.Inputs = make([]float32, sm.NInputs)
+    sm.Weights.SetShape([]int{sm.NCats, sm.NInputs}, nil, []string{"Cats", "Inputs"})
+    for i := range sm.Weights.Values {
+        sm.Weights.Values[i] = .1
+    }
+}
+
 // Decode decodes the given variable name from layers (forward pass)
-func (sm *SoftMax) Decode(varNm string) {
+// See Sorted list of indexes for the decoding output -- i.e., Sorted[0]
+// is the most likely -- that is returned here as a convenience.
+func (sm *SoftMax) Decode(varNm string) int {
     sm.Input(varNm)
     sm.Forward()
+    sm.Sort()
+    return sm.Sorted[0]
 }
 
 // Train trains the decoder with given target correct answer (0..NCats-1)
@@ -107,16 +124,36 @@ func (sm *SoftMax) Forward() {
     }
     for ui := range sm.Units {
         u := &sm.Units[ui]
-        u.Act /= sum
+        u.Act = u.Exp / sum
     }
 }
 
+// Sort updates Sorted indexes of the current Unit category activations sorted
+// from highest to lowest. i.e., the 0-index value has the strongest
+// decoded output category, 1 the next-strongest, etc.
+func (sm *SoftMax) Sort() {
+    for i := range sm.Sorted {
+        sm.Sorted[i] = i
+    }
+    sort.Slice(sm.Sorted, func(i, j int) bool {
+        return sm.Units[sm.Sorted[i]].Act > sm.Units[sm.Sorted[j]].Act
+    })
+}
+
 // Back compute the backward error propagation pass
 func (sm *SoftMax) Back() {
-    // for ui := range sm.Units {
-    // u := &sm.Units[ui]
-    // for ui := range sm.Units {
-    // u := &sm.Units[ui]
-    // }
-    // }
+    lr := sm.Lrate
+    for ui := range sm.Units {
+        u := &sm.Units[ui]
+        var del float32
+        if ui == sm.Targ {
+            del = lr * (1 - u.Act)
+        } else {
+            del = -lr * u.Act
+        }
+        off := ui * sm.NInputs
+        for j, in := range sm.Inputs {
+            sm.Weights.Values[off+j] += del * in
+        }
+    }
 }
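
As a side note (not part of the commit), the new `Back` pass is the standard delta rule you get from gradient descent on the softmax cross-entropy loss; a brief sketch of the math the code follows:

```latex
% Softmax activations over categories i, with target category t:
%   a_i = e^{g_i} / \sum_k e^{g_k},   loss L = -\log a_t
% The gradient of L w.r.t. the net input g_i is (a_i - \delta_{it}),
% so gradient descent with learning rate \eta on weights w_{ij}
% (input x_j) gives:
\[
  \Delta w_{ij} = -\eta \,\frac{\partial L}{\partial w_{ij}}
                = \eta \,(\delta_{it} - a_i)\, x_j
\]
% This matches Back(): del = Lrate*(1-Act) for the target unit,
% del = -Lrate*Act for all others, and Weights[ui][j] += del * Inputs[j].
```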
39 changes: 39 additions & 0 deletions decoder/softmax_test.go
@@ -0,0 +1,39 @@
// Copyright (c) 2021, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package decoder

import (
    "testing"
)

func TestSoftMax(t *testing.T) {
    dec := SoftMax{}
    dec.Init(2, 2)
    dec.Lrate = .1
    for i := 0; i < 100; i++ {
        trg := 0
        if i%2 == 0 {
            dec.Inputs[0] = 1
            dec.Inputs[1] = 0
        } else {
            trg = 1
            dec.Inputs[0] = 0
            dec.Inputs[1] = 1
        }
        dec.Forward()
        dec.Sort()
        // fmt.Printf("%d\t%d\t%v", i, trg, dec.Sorted)
        // for j := 0; j < 2; j++ {
        // 	fmt.Printf("\t%g", dec.Units[j].Act)
        // }
        // fmt.Printf("\n")
        if i > 2 {
            if dec.Sorted[0] != trg {
                t.Errorf("err: %d\t%d\t%v\n", i, trg, dec.Sorted)
            }
        }
        dec.Train(trg)
    }
}
