Support for opset 14 #239

Open
wants to merge 5 commits into develop

README.md (1 addition, 1 deletion)
@@ -9,7 +9,7 @@ The goal of this package is to provide an easy way of running ONNX models in Go.
is intended for inference usage of ONNX models. The package can be used to load an `.onnx` file
and perform inference using the model described by this file.

-Currently, we are implementing ONNX operation set 13, and we plan to add all opsets following this
+Currently, ONNX opset versions 7 through 14 are supported, and we plan to add all opsets following this
one as well. Feel free to contribute by implementing operators!

## Getting started
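For context, the flow this README paragraph describes looks roughly as follows. This is a minimal sketch assuming the `NewModelFromFile`/`Tensors`/`Run` API shown elsewhere in the project README; the model path and the input name `"data"` are placeholders:

```go
package main

import (
	"fmt"

	"github.com/advancedclimatesystems/gonnx"
	"gorgonia.org/tensor"
)

func main() {
	// Load the model; with this PR, graphs declaring opsets 7 through 14 resolve.
	model, err := gonnx.NewModelFromFile("./model.onnx")
	if err != nil {
		panic(err)
	}

	// Keys must match the graph's input names; "data" is a placeholder here.
	inputs := gonnx.Tensors{
		"data": tensor.New(
			tensor.WithShape(1, 3),
			tensor.WithBacking([]float32{0.1, 0.2, 0.3}),
		),
	}

	result, err := model.Run(inputs)
	if err != nil {
		panic(err)
	}

	fmt.Println(result)
}
```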
ops/add/versions.go (7 additions, 0 deletions)
@@ -2,13 +2,20 @@ package add

import (
	"github.com/advancedclimatesystems/gonnx/ops"
+	"gorgonia.org/tensor"
)

var addVersions = ops.OperatorVersions{
	7:  ops.NewOperatorConstructor(newAdd, 7, addTypeConstraints),
	13: ops.NewOperatorConstructor(newAdd, 13, addTypeConstraints),
+	14: ops.NewOperatorConstructor(newAdd, 14, add14TypeConstraints),
}

func GetVersions() ops.OperatorVersions {
	return addVersions
}
+
+var add14TypeConstraints = [][]tensor.Dtype{
+	ops.NumericTypes,
+	ops.NumericTypes,
+}
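All the `versions.go` changes in this PR follow the same registry pattern: an operator package maps opset versions to constructors and exposes them via `GetVersions`. Under ONNX versioning rules, a model importing opset N gets the highest registered version that does not exceed N. A self-contained toy sketch of that dispatch rule (stand-in types, not gonnx's actual resolver):

```go
package main

import "fmt"

// operatorVersions is a toy stand-in for gonnx's ops.OperatorVersions:
// it maps an opset version to an operator constructor.
type operatorVersions map[int]func() string

// pickVersion selects the highest registered version that does not exceed
// the model's opset, which is how ONNX defines operator dispatch.
func pickVersion(versions operatorVersions, opset int) (func() string, bool) {
	best := -1

	for v := range versions {
		if v <= opset && v > best {
			best = v
		}
	}

	if best < 0 {
		return nil, false
	}

	return versions[best], true
}

func main() {
	addVersions := operatorVersions{
		7:  func() string { return "Add v7" },
		13: func() string { return "Add v13" },
		14: func() string { return "Add v14" },
	}

	if newOp, ok := pickVersion(addVersions, 13); ok {
		fmt.Println(newOp()) // a model at opset 13 gets "Add v13"
	}
}
```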
ops/cumsum/versions.go (1 addition, 0 deletions)
@@ -6,6 +6,7 @@ import (

var cumsumVersions = ops.OperatorVersions{
	11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints),
+	14: ops.NewOperatorConstructor(newCumSum, 14, cumsumTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
ops/div/versions.go (7 additions, 0 deletions)
@@ -2,13 +2,20 @@ package div

import (
	"github.com/advancedclimatesystems/gonnx/ops"
+	"gorgonia.org/tensor"
)

var divVersions = ops.OperatorVersions{
	7:  ops.NewOperatorConstructor(newDiv, 7, divTypeConstraints),
	13: ops.NewOperatorConstructor(newDiv, 13, divTypeConstraints),
+	14: ops.NewOperatorConstructor(newDiv, 14, div14TypeConstraints),
}

func GetVersions() ops.OperatorVersions {
	return divVersions
}
+
+var div14TypeConstraints = [][]tensor.Dtype{
+	ops.NumericTypes,
+	ops.NumericTypes,
+}
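The `*TypeConstraints` tables hold one list of allowed dtypes per input position, so `div14TypeConstraints` above allows numeric types for both operands. A toy validator in that spirit; the `validateTypes` helper is hypothetical, not gonnx's API:

```go
package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

// validateTypes checks each input's dtype against the list of allowed
// dtypes for its position, mirroring how the constraint tables are shaped.
func validateTypes(constraints [][]tensor.Dtype, inputs []tensor.Tensor) error {
	for i, in := range inputs {
		allowed := false

		for _, dt := range constraints[i] {
			if in.Dtype() == dt {
				allowed = true

				break
			}
		}

		if !allowed {
			return fmt.Errorf("input %d has unsupported dtype %v", i, in.Dtype())
		}
	}

	return nil
}

func main() {
	numeric := []tensor.Dtype{tensor.Float32, tensor.Float64, tensor.Int32, tensor.Int64}
	constraints := [][]tensor.Dtype{numeric, numeric}

	a := tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{1, 2}))
	b := tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{3, 4}))

	fmt.Println(validateTypes(constraints, []tensor.Tensor{a, b})) // <nil>
}
```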
ops/gru/gru.go (58 additions, 21 deletions)
@@ -31,6 +31,7 @@ type GRU struct {
	activations       []string
	direction         ops.SequenceProcessDirection
	hiddenSize        int
+	layout            int
	linearBeforeReset bool
}

@@ -46,6 +47,7 @@ func newGRU(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
		),
		activations:       []string{"sigmoid", "tanh"},
		direction:         ops.Forward,
+		layout:            0,
		linearBeforeReset: false,
	}
}
@@ -77,6 +79,13 @@ func (g *GRU) Init(n *onnx.NodeProto) error {
		}
	case ops.HiddenSizeAttr:
		g.hiddenSize = int(attr.GetI())
+	case ops.LayoutAttr:
+		// The 'layout' attribute is only supported since opset version 14.
+		if g.Version() < 14 {
+			return ops.ErrInvalidAttribute(attr.GetName(), g)
+		}
+
+		g.layout = int(attr.GetI())
	case "linear_before_reset":
		g.linearBeforeReset = ops.Int64ToBool(attr.GetI())
	default:
@@ -93,9 +102,10 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
		return nil, ops.ErrUnsupportedInput("sequence lens", g.BaseOperator)
	}

-	X := inputs[0]
-	seqLength := X.Shape()[0]
-	batchSize := X.Shape()[1]
+	X, seqLength, batchSize, err := ops.ReshapeInputTensorBasedOnLayout(inputs[0], g.layout)
+	if err != nil {
+		return nil, err
+	}

	Wz, Wr, Wh, err := g.getWeights(inputs[1])
	if err != nil {
@@ -119,17 +129,36 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
		return nil, err
	}

-	prevH := inputs[5]
-	if prevH == nil {
-		prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize)
+	var prevH tensor.Tensor
+
+	if inputs[5] == nil {
+		if g.layout == 1 {
+			prevH = ops.ZeroTensor(batchSize, 1, g.hiddenSize)
+		} else {
+			prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize)
+		}
+	} else {
+		var ok bool
+
+		prevH, ok = inputs[5].Clone().(tensor.Tensor)
+		if !ok {
+			return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[5].Clone())
+		}
	}

-	// Extract the shape of the hidden dimensions without the bidirectional dimension, as
-	// we do not support bidirectional GRU yet.
-	shapeWithoutBidir := prevH.Shape().Clone()[1:]
+	// If layout is 1, batch size comes as the first dimension, so we
+	// reshape the hidden tensor here to the default layout.
+	if g.layout == 1 {
+		numDirections := prevH.Shape()[1]
+		if err = prevH.Reshape(numDirections, batchSize, g.hiddenSize); err != nil {
+			return nil, err
+		}
+	}

-	err = prevH.Reshape(shapeWithoutBidir...)
-	if err != nil {
+	// Reshape the hidden tensor without the bidirectional dimension, as
+	// we do not support bidirectional GRU yet. This is the dimension at
+	// index 0.
+	if err = prevH.Reshape(prevH.Shape().Clone()[1:]...); err != nil {
		return nil, err
	}

@@ -184,21 +213,29 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
		Y = outputs[0]
	}

-	// Reshape the output so it adds the num_directions as specified by onnx.
-	err = Y.Reshape([]int{seqLength, 1, batchSize, g.hiddenSize}...)
-	if err != nil {
-		return nil, err
-	}

	Yh, ok := prevH.Clone().(tensor.Tensor)
	if !ok {
		return nil, ops.ErrTypeAssert("tensor.Tensor", prevH.Clone())
	}

-	// Reshape the output so it adds the num_directions as specified by onnx.
-	err = Yh.Reshape([]int{1, batchSize, g.hiddenSize}...)
-	if err != nil {
-		return nil, err
+	// Reshape the output according to the specified layout and re-add the
+	// num_directions dimension.
+	if g.layout == 1 {
+		if err = Y.Reshape(batchSize, seqLength, 1, g.hiddenSize); err != nil {
+			return nil, err
+		}
+
+		if err = Yh.Reshape(batchSize, 1, g.hiddenSize); err != nil {
+			return nil, err
+		}
+	} else {
+		if err = Y.Reshape(seqLength, 1, batchSize, g.hiddenSize); err != nil {
+			return nil, err
+		}
+
+		if err = Yh.Reshape(1, batchSize, g.hiddenSize); err != nil {
+			return nil, err
+		}
	}

	return []tensor.Tensor{Y, Yh}, nil
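To summarize the reshape logic above: for the unidirectional case gonnx supports (num_directions = 1), the ONNX GRU spec pins the output shapes per layout. A small sketch of just that shape rule; the `expectedShapes` helper is ours, for illustration only:

```go
package main

import "fmt"

// expectedShapes returns the ONNX-specified shapes of Y and Y_h for a
// unidirectional GRU under both supported layouts.
func expectedShapes(layout, seqLength, batchSize, hiddenSize int) (y, yh []int) {
	const numDirections = 1

	if layout == 1 {
		// Batch-first layout, selectable from opset 14 via the 'layout' attribute.
		return []int{batchSize, seqLength, numDirections, hiddenSize},
			[]int{batchSize, numDirections, hiddenSize}
	}

	// Default time-major layout (layout == 0).
	return []int{seqLength, numDirections, batchSize, hiddenSize},
		[]int{numDirections, batchSize, hiddenSize}
}

func main() {
	y, yh := expectedShapes(1, 5, 2, 4)
	fmt.Println(y, yh) // [2 5 1 4] [2 1 4]
}
```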
ops/gru/gru_test.go (86 additions, 5 deletions)
@@ -46,11 +46,12 @@ func TestGruInitUnkownAttr(t *testing.T) {

func TestGru(t *testing.T) {
	tests := []struct {
-		version  int64
-		node     *onnx.NodeProto
-		inputs   ops.InputFixture
-		expected []float32
-		err      error
+		version       int64
+		node          *onnx.NodeProto
+		inputs        ops.InputFixture
+		expected      []float32
+		expectedShape tensor.Shape
+		err           error
	}{
		{
			7,
@@ -67,6 +68,7 @@
			gruInput0,
			[]float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00},
+			nil,
			nil,
		},
		{
			7,
@@ -83,6 +85,7 @@
			gruInput0,
			[]float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00},
+			nil,
			nil,
		},
		{
			7,
@@ -99,6 +102,7 @@
			gruInput1,
			[]float32{0.44905475, 0.4406946, 0.43368173, 0.42782417},
+			nil,
			nil,
		},
		{
			7,
@@ -115,6 +119,45 @@
			gruInputNoBNoH,
			[]float32{0.24553154, 0.24553154, 0.24553154, 0.24553154},
+			nil,
			nil,
		},
+		{
+			14,
+			&onnx.NodeProto{
+				Attribute: []*onnx.AttributeProto{
+					{Name: "activation_alpha", Floats: []float32{}},
+					{Name: "activation_beta", Floats: []float32{}},
+					{Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}},
+					{Name: "direction", S: []byte("forward")},
+					{Name: "hidden_size", I: 4},
+					{Name: "layout", I: 1},
+					{Name: "linear_before_reset", I: 0},
+				},
+			},
+			gruInputBatchFirst,
+			[]float32{0.0066928267, 8.34465e-07, 0, 0, 8.34465e-07, 0, 0, 0},
+			// shape [batch_size, sequence_length, num_directions, hidden_size]
+			[]int{2, 5, 1, 4},
+			nil,
+		},
+		{
+			14,
+			&onnx.NodeProto{
+				Attribute: []*onnx.AttributeProto{
+					{Name: "activation_alpha", Floats: []float32{}},
+					{Name: "activation_beta", Floats: []float32{}},
+					{Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}},
+					{Name: "direction", S: []byte("forward")},
+					{Name: "hidden_size", I: 4},
+					{Name: "layout", I: 1},
+					{Name: "linear_before_reset", I: 0},
+				},
+			},
+			gruInputBatchFirstNoH,
+			[]float32{0.0066928267, 8.34465e-07, 0, 0, 8.34465e-07, 0, 0, 0},
+			// shape [batch_size, sequence_length, num_directions, hidden_size]
+			[]int{2, 5, 1, 4},
+			nil,
+		},
	}

@@ -130,6 +173,10 @@

		if err == nil {
			assert.Equal(t, test.expected, res[1].Data())
+
+			if test.expectedShape != nil {
+				assert.Equal(t, test.expectedShape, res[0].Shape())
+			}
		}
	}
}
@@ -313,6 +360,40 @@ func gruInputNoBNoH() []tensor.Tensor {
	return inputs
}

+func gruInputBatchFirst() []tensor.Tensor {
+	return []tensor.Tensor{
+		// Input X: (batch_size, sequence_length, input_size).
+		ops.Float32TensorFixture(2, 5, 3),
+		// Input W: (num_directions, 3 * hidden_size, input_size).
+		ops.Float32TensorFixture(1, 12, 3),
+		// Input R: (num_directions, 3 * hidden_size, hidden_size).
+		ops.Float32TensorFixture(1, 12, 4),
+		// Input B: not provided.
+		nil,
+		// Input sequence_lens: not supported.
+		nil,
+		// Input initial_h: (batch_size, num_directions, hidden_size).
+		ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(2, 1, 4)), 2, 1, 4),
+	}
+}
+
+func gruInputBatchFirstNoH() []tensor.Tensor {
+	return []tensor.Tensor{
+		// Input X: (batch_size, sequence_length, input_size).
+		ops.Float32TensorFixture(2, 5, 3),
+		// Input W: (num_directions, 3 * hidden_size, input_size).
+		ops.Float32TensorFixture(1, 12, 3),
+		// Input R: (num_directions, 3 * hidden_size, hidden_size).
+		ops.Float32TensorFixture(1, 12, 4),
+		// Input B: not provided.
+		nil,
+		// Input sequence_lens: not supported.
+		nil,
+		// Input initial_h: not provided.
+		nil,
+	}
+}

func GRUOnnxNodeProtoFixture() *onnx.NodeProto {
	return &onnx.NodeProto{
		Attribute: []*onnx.AttributeProto{
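The batch-first fixtures above feed X as (batch_size, sequence_length, input_size), while the default ONNX layout is time-major (sequence_length, batch_size, input_size); moving between the two amounts to swapping the first two axes. A standalone illustration with gorgonia's tensor package, independent of the helpers this PR uses:

```go
package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

func main() {
	// Batch-first X: (batch_size=2, seq_length=3, input_size=1).
	x := tensor.New(
		tensor.WithShape(2, 3, 1),
		tensor.WithBacking([]float32{1, 2, 3, 4, 5, 6}),
	)

	// Swap the first two axes to obtain the default time-major layout
	// (seq_length, batch_size, input_size).
	if err := x.T(1, 0, 2); err != nil {
		panic(err)
	}

	fmt.Println(x.Shape()) // (3, 2, 1)
}
```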
ops/gru/versions.go (2 additions, 1 deletion)
@@ -3,7 +3,8 @@ package gru
import "github.com/advancedclimatesystems/gonnx/ops"

var gruVersions = ops.OperatorVersions{
-	7: ops.NewOperatorConstructor(newGRU, 7, gruTypeConstraints),
+	7:  ops.NewOperatorConstructor(newGRU, 7, gruTypeConstraints),
+	14: ops.NewOperatorConstructor(newGRU, 14, gruTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
ops/identity/versions.go (1 addition, 0 deletions)
@@ -6,6 +6,7 @@ import (

var identityVersions = ops.OperatorVersions{
	13: ops.NewOperatorConstructor(newIdentity, 13, identityTypeConstraints),
+	14: ops.NewOperatorConstructor(newIdentity, 14, identityTypeConstraints),
}

func GetVersions() ops.OperatorVersions {