-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.go
158 lines (143 loc) · 3.86 KB
/
model.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package nn
import (
"go4ml.xyz/base/fu"
"go4ml.xyz/base/model"
"go4ml.xyz/base/tables"
"go4ml.xyz/iokit"
"go4ml.xyz/nn/mx"
"go4ml.xyz/zorros"
"golang.org/x/xerrors"
"gopkg.in/yaml.v3"
"io"
)
const (
	// DefaultBatchSize is the default batch size for general neural network training.
	DefaultBatchSize = 32
)
// Model is an ANN/Gym definition: the network to train together with the
// settings used to train it. Call Feed to bind it to a dataset.
type Model struct {
	// Network is the network definition to be trained.
	Network Block

	// Optimizer configures the optimizer used during training.
	Optimizer OptimizerConf

	// Loss is the loss function minimized during training.
	Loss mx.Loss

	// Input is the dimension of the network input.
	Input mx.Dimension

	// Seed presumably seeds random initialization — TODO confirm in Train.
	Seed int

	// BatchSize is the training batch size (see DefaultBatchSize).
	BatchSize int

	// Predicted presumably names the prediction column of the trained
	// model — NOTE(review): confirm against Train.
	Predicted string

	// Context selects the compute device; CPU by default.
	Context mx.Context
}
// Feed binds the model definition e to the dataset ds. The returned FatModel
// trains the network on ds, using DefaultModelMap, when given a workout.
func (e Model) Feed(ds model.Dataset) model.FatModel {
	train := func(workout model.Workout) (*model.Report, error) {
		return Train(e, ds, workout, DefaultModelMap)
	}
	return train
}
// PredictionModel describes a restored neural network model and acts as the
// factory for FeaturesMapper instances that run it over tables.
type PredictionModel struct {
	features []string // feature column names fed to the network
	predicts string   // name of the column that receives predictions

	symbol iokit.Input // stored network symbol, loaded lazily
	params iokit.Input // stored network parameters, loaded lazily

	context mx.Context // device the network is loaded on
}
// Features returns the feature column names the model reads when it maps a
// table; they are expected to match the features of the training dataset.
func (pm PredictionModel) Features() []string {
	return pm.features
}
// Predicted returns the name of the column the model adds to the result
// table when it maps features. NOTE(review): documented elsewhere as
// 'Predicted' by default; that default is not set in this file — confirm.
func (pm PredictionModel) Predicted() string {
	return pm.predicts
}
// FeaturesMapper loads the model's network (from its symbol and params
// inputs, on the model's context) sized for batchSize rows. It returns a
// mapper that runs that network over tables.
//
// The returned mapper owns the loaded network; call its Close method to
// release it. If loading fails, the error is returned wrapped with context
// and the mapper is nil.
func (pm PredictionModel) FeaturesMapper(batchSize int) (fm tables.FeaturesMapper, err error) {
	network, err := Load(pm.context, pm.symbol, pm.params, batchSize)
	if err != nil {
		return nil, xerrors.Errorf("loading prediction network: %w", err)
	}
	return &FeaturesMapper{model: pm, network: network}, nil
}
// Gpu returns a copy of the prediction model whose network runs on a GPU.
// Without arguments mx.GPU0 is used. Otherwise the device is mx.Gpu(no[0]),
// and any further arguments are ignored.
func (pm PredictionModel) Gpu(no ...int) model.PredictionModel {
	if len(no) == 0 {
		pm.context = mx.GPU0
	} else {
		pm.context = mx.Gpu(no[0])
	}
	return pm
}
// FeaturesMapper maps the feature columns of a table to predictions by
// running a loaded network; Close releases that network.
type FeaturesMapper struct {
	model   PredictionModel // column names: features read and prediction written
	network *Network        // loaded network used for the forward pass
}
// MapFeatures runs the network over the feature columns of t. It returns a
// new table holding every original column except the features, plus one
// column (named by the model's predicts field) with the network output.
//
// t must not have more rows than the network batch size, and the feature
// matrix width must equal the per-sample width of the network input.
// Otherwise an error is returned and the network is not run.
func (fm *FeaturesMapper) MapFeatures(t *tables.Table) (*tables.Table, error) {
	batchSize := fm.network.BatchSize
	rows := t.Len()

	// Reject oversized tables before materializing the feature matrix.
	if rows > batchSize {
		return nil, xerrors.Errorf("batch size does not fit network input")
	}

	input, err := t.Matrix(fm.model.features, batchSize)
	if err != nil {
		return nil, err
	}
	if input.Width != fm.network.Input.Dim().Total()/batchSize {
		return nil, xerrors.Errorf("features does not fit network input")
	}

	// The output buffer spans the whole network batch; only the first
	// rows*outWidth values belong to the rows of t.
	outTotal := fm.network.Output.Dim().Total()
	outWidth := outTotal / batchSize
	out := make([]float32, outTotal)
	fm.network.Forward(input.Features, out)

	predictions := tables.MatrixColumn(out[:outWidth*rows], rows)
	return t.Except(fm.model.features...).With(predictions, fm.model.predicts), nil
}
// Close releases the network held by the mapper. It always returns nil.
func (fm *FeaturesMapper) Close() error {
	network := fm.network
	network.Release()
	return nil
}
// ObjectifyModel restores a neural network PredictionModel from the parts
// of a saved model collection c.
//
// The ModelPartInfo part is decoded as YAML: its "features" entry is
// converted with fu.Strings, and its "predicts" entry must be a string.
// The ModelPartSymbol and ModelPartParams inputs are stored unopened; they
// are loaded later by FeaturesMapper.
//
// An error is returned when the info part is missing, cannot be opened or
// decoded, or lacks a string "predicts" entry.
func ObjectifyModel(c map[string]iokit.Input) (pm model.PredictionModel, err error) {
	info, ok := c[ModelPartInfo]
	if !ok {
		return nil, zorros.New("it's not neural network model")
	}

	var rd io.ReadCloser
	if rd, err = info.Open(); err != nil {
		return nil, err
	}
	defer rd.Close()

	cf := map[string]interface{}{}
	if err = yaml.NewDecoder(rd).Decode(&cf); err != nil {
		return nil, err
	}

	// Checked assertion: a missing or non-string entry would otherwise panic.
	predicts, ok := cf["predicts"].(string)
	if !ok {
		return nil, zorros.New("neural network model info has no string 'predicts' entry")
	}

	return PredictionModel{
		symbol:   c[ModelPartSymbol],
		params:   c[ModelPartParams],
		features: fu.Strings(cf["features"]),
		predicts: predicts,
	}, nil
}
// Objectify reads the model collection stored in source and restores the
// neural network prediction model from it, using ObjectifyModel.
//
// The collection name is chosen by fu.Fnzs from collection, falling back to
// "model" (presumably the first non-empty value — confirm fu.Fnzs
// semantics).
//
// An error is returned if the collection cannot be read, or if the named
// entry is missing or is not a GpuPredictionModel.
func Objectify(source iokit.Input, collection ...string) (fm model.GpuPredictionModel, err error) {
	name := fu.Fnzs(fu.Fnzs(collection...), "model")
	m, err := model.Objectify(source, model.ObjectifyMap{name: ObjectifyModel})
	if err != nil {
		return nil, err
	}
	// Checked assertion: a missing entry (nil) or a foreign model type would
	// otherwise panic.
	gpm, ok := m[name].(model.GpuPredictionModel)
	if !ok {
		return nil, xerrors.Errorf("collection %q does not hold a GPU-capable prediction model", name)
	}
	return gpm, nil
}
// LuckyObjectify behaves like Objectify but panics with the error instead of
// returning it, for callers that cannot recover from a missing or broken
// model.
func LuckyObjectify(source iokit.Input, collection ...string) model.GpuPredictionModel {
	restored, err := Objectify(source, collection...)
	if err == nil {
		return restored
	}
	panic(err)
}