From 354ab24be072e6ceb680f7e1af3ad22bc8f6212f Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 30 Mar 2025 11:10:31 +0200 Subject: [PATCH 1/4] Support for opset 14 --- ops/add/versions.go | 7 +++ ops/cumsum/versions.go | 1 + ops/div/versions.go | 7 +++ ops/gru/gru.go | 83 +++++++++++++++++++++++++------- ops/gru/gru_test.go | 91 ++++++++++++++++++++++++++++++++++-- ops/gru/versions.go | 3 +- ops/lstm/lstm.go | 87 +++++++++++++++++++++++++++++----- ops/lstm/lstm_test.go | 104 +++++++++++++++++++++++++++++++++++++++-- ops/lstm/versions.go | 3 +- ops/mul/versions.go | 11 ++++- ops/recurrent_utils.go | 1 + ops/relu/versions.go | 8 +++- ops/rnn/rnn.go | 72 +++++++++++++++++++++++----- ops/rnn/rnn_test.go | 95 +++++++++++++++++++++++++++++++++++-- ops/rnn/versions.go | 3 +- ops/sub/versions.go | 11 ++++- ops_test.go | 12 ++--- opset.go | 4 +- 18 files changed, 533 insertions(+), 70 deletions(-) diff --git a/ops/add/versions.go b/ops/add/versions.go index 3a588ee..f413f58 100644 --- a/ops/add/versions.go +++ b/ops/add/versions.go @@ -2,13 +2,20 @@ package add import ( "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" ) var addVersions = ops.OperatorVersions{ 7: ops.NewOperatorConstructor(newAdd, 7, addTypeConstraints), 13: ops.NewOperatorConstructor(newAdd, 13, addTypeConstraints), + 14: ops.NewOperatorConstructor(newAdd, 14, add14TypeConstraints), } func GetVersions() ops.OperatorVersions { return addVersions } + +var add14TypeConstraints = [][]tensor.Dtype{ + ops.NumericTypes, + ops.NumericTypes, +} diff --git a/ops/cumsum/versions.go b/ops/cumsum/versions.go index 89fceab..aa2763d 100644 --- a/ops/cumsum/versions.go +++ b/ops/cumsum/versions.go @@ -6,6 +6,7 @@ import ( var cumsumVersions = ops.OperatorVersions{ 11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints), + 14: ops.NewOperatorConstructor(newCumSum, 14, cumsumTypeConstraints), } func GetVersions() ops.OperatorVersions { diff --git a/ops/div/versions.go b/ops/div/versions.go index cd1cb89..8ab80ec 100644 --- a/ops/div/versions.go +++ b/ops/div/versions.go @@ -2,13 +2,20 @@ package div import ( "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" ) var divVersions = ops.OperatorVersions{ 7: ops.NewOperatorConstructor(newDiv, 7, divTypeConstraints), 13: ops.NewOperatorConstructor(newDiv, 13, divTypeConstraints), + 14: ops.NewOperatorConstructor(newDiv, 14, div14TypeConstraints), } func GetVersions() ops.OperatorVersions { return divVersions } + +var div14TypeConstraints = [][]tensor.Dtype{ + ops.NumericTypes, + ops.NumericTypes, +} diff --git a/ops/gru/gru.go b/ops/gru/gru.go index f5d97df..f73d7d1 100644 --- a/ops/gru/gru.go +++ b/ops/gru/gru.go @@ -31,6 +31,7 @@ type GRU struct { activations []string direction ops.SequenceProcessDirection hiddenSize int + layout int linearBeforeReset bool } @@ -46,6 +47,7 @@ func newGRU(version int, typeConstraints [][]tensor.Dtype) ops.Operator { ), activations: []string{"sigmoid", "tanh"}, direction: ops.Forward, + layout: 0, linearBeforeReset: false, } } @@ -77,6 +79,13 @@ func (g *GRU) Init(n *onnx.NodeProto) error { } case ops.HiddenSizeAttr: g.hiddenSize = int(attr.GetI()) + case ops.LayoutAttr: + // 'layout' is supported since version 14 + if g.Version() < 14 { + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + + g.layout = int(attr.GetI()) case "linear_before_reset": g.linearBeforeReset = ops.Int64ToBool(attr.GetI()) default: @@ -94,8 +103,27 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } X := inputs[0] 
-	seqLength := X.Shape()[0]
-	batchSize := X.Shape()[1]
+
+	var seqLength int
+
+	var batchSize int
+
+	// The 'layout' parameter handles whether or not the batch dimension comes
+	// first in the tensor. If this is the case, we reshape it here at the
+	// beginning of the operation, and reverse it at the end of the operation.
+	if g.layout == 1 {
+		seqLength = X.Shape()[1]
+		batchSize = X.Shape()[0]
+		inputSize := X.Shape()[2]
+
+		err := X.Reshape(seqLength, batchSize, inputSize)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		seqLength = X.Shape()[0]
+		batchSize = X.Shape()[1]
+	}
 
 	Wz, Wr, Wh, err := g.getWeights(inputs[1])
 	if err != nil {
@@ -121,15 +149,26 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 
 	prevH := inputs[5]
 	if prevH == nil {
-		prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize)
+		if g.layout == 1 {
+			prevH = ops.ZeroTensor(batchSize, 1, g.hiddenSize)
+		} else {
+			prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize)
+		}
 	}
 
-	// Extract the shape of the hidden dimensions without the bidirectional dimension, as
-	// we do not support bidirectional GRU yet.
-	shapeWithoutBidir := prevH.Shape().Clone()[1:]
+	// If layout is 1, this means batch size comes as first dimension, and
+	// we reshape it here to the default layout.
+	if g.layout == 1 {
+		numDirections := prevH.Shape()[1]
+		if err = prevH.Reshape(numDirections, batchSize, g.hiddenSize); err != nil {
+			return nil, err
+		}
+	}
 
-	err = prevH.Reshape(shapeWithoutBidir...)
-	if err != nil {
+	// Reshape the hidden tensor without the bidirectional dimension, as
+	// we do not support bidirectional GRU yet. This is the dimension at
+	// index 0.
+	if err = prevH.Reshape(prevH.Shape().Clone()[1:]...); err != nil {
 		return nil, err
 	}
 
@@ -184,21 +223,29 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		Y = outputs[0]
 	}
 
-	// Reshape the output so it adds the num_directions as specified by onnx.
-	err = Y.Reshape([]int{seqLength, 1, batchSize, g.hiddenSize}...)
-	if err != nil {
-		return nil, err
-	}
-
 	Yh, ok := prevH.Clone().(tensor.Tensor)
 	if !ok {
 		return nil, ops.ErrTypeAssert("tensor.Tensor", prevH.Clone())
 	}
 
-	// Reshape the output so it adds the num_directions as specified by onnx.
-	err = Yh.Reshape([]int{1, batchSize, g.hiddenSize}...)
-	if err != nil {
-		return nil, err
+	// Reshape the output according to the specified layout and re-add the
+	// num_directions dimension.
+ if g.layout == 1 { + if err = Y.Reshape(batchSize, seqLength, 1, g.hiddenSize); err != nil { + return nil, err + } + + if err = Yh.Reshape(batchSize, 1, g.hiddenSize); err != nil { + return nil, err + } + } else { + if err = Y.Reshape(seqLength, 1, batchSize, g.hiddenSize); err != nil { + return nil, err + } + + if err = Yh.Reshape(1, batchSize, g.hiddenSize); err != nil { + return nil, err + } } return []tensor.Tensor{Y, Yh}, nil diff --git a/ops/gru/gru_test.go b/ops/gru/gru_test.go index 789f509..8ea6120 100644 --- a/ops/gru/gru_test.go +++ b/ops/gru/gru_test.go @@ -46,11 +46,12 @@ func TestGruInitUnkownAttr(t *testing.T) { func TestGru(t *testing.T) { tests := []struct { - version int64 - node *onnx.NodeProto - inputs ops.InputFixture - expected []float32 - err error + version int64 + node *onnx.NodeProto + inputs ops.InputFixture + expected []float32 + expectedShape tensor.Shape + err error }{ { 7, @@ -67,6 +68,7 @@ func TestGru(t *testing.T) { gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, + nil, }, { 7, @@ -83,6 +85,7 @@ func TestGru(t *testing.T) { gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, + nil, }, { 7, @@ -99,6 +102,7 @@ func TestGru(t *testing.T) { gruInput1, []float32{0.44905475, 0.4406946, 0.43368173, 0.42782417}, nil, + nil, }, { 7, @@ -115,6 +119,45 @@ func TestGru(t *testing.T) { gruInputNoBNoH, []float32{0.24553154, 0.24553154, 0.24553154, 0.24553154}, nil, + nil, + }, + { + 14, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "layout", I: 1}, + {Name: "linear_before_reset", I: 0}, + }, + }, + gruInputBatchFirst, + []float32{0.0066928267, 8.34465e-07, 0, 0, 8.34465e-07, 0, 0, 0}, + // shape [batch_size, sequence_length, num_directions, hidden_size] + []int{2, 5, 1, 4}, + nil, + }, + { + 14, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "layout", I: 1}, + {Name: "linear_before_reset", I: 0}, + }, + }, + gruInputBatchFirstNoH, + []float32{0.0066928267, 8.34465e-07, 0, 0, 8.34465e-07, 0, 0, 0}, + // shape [batch_size, sequence_length, num_directions, hidden_size] + []int{2, 5, 1, 4}, + nil, }, } @@ -130,6 +173,10 @@ func TestGru(t *testing.T) { if err == nil { assert.Equal(t, test.expected, res[1].Data()) + + if test.expectedShape != nil { + assert.Equal(t, test.expectedShape, res[0].Shape()) + } } } } @@ -313,6 +360,40 @@ func gruInputNoBNoH() []tensor.Tensor { return inputs } +func gruInputBatchFirst() []tensor.Tensor { + return []tensor.Tensor{ + // Input X: (batch_size, sequence_length, input_size). + ops.Float32TensorFixture(2, 5, 3), + // Input W: (num_directions, 3 * hidden_size, input_size). + ops.Float32TensorFixture(1, 12, 3), + // Input R: (num_directions, 3 * hidden_size, hidden_size). + ops.Float32TensorFixture(1, 12, 4), + // Input B: not provided. 
+		nil,
+		// Input sequence_lens: not supported
+		nil,
+		// Input initial_h: (batch_size, num_directions, hidden_size)
+		ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(2, 1, 4)), 2, 1, 4),
+	}
+}
+
+func gruInputBatchFirstNoH() []tensor.Tensor {
+	return []tensor.Tensor{
+		// Input X: (batch_size, sequence_length, input_size).
+		ops.Float32TensorFixture(2, 5, 3),
+		// Input W: (num_directions, 3 * hidden_size, input_size).
+		ops.Float32TensorFixture(1, 12, 3),
+		// Input R: (num_directions, 3 * hidden_size, hidden_size).
+		ops.Float32TensorFixture(1, 12, 4),
+		// Input B: not provided.
+		nil,
+		// Input sequence_lens: not supported
+		nil,
+		// Input initial_h: not provided
+		nil,
+	}
+}
+
 func GRUOnnxNodeProtoFixture() *onnx.NodeProto {
 	return &onnx.NodeProto{
 		Attribute: []*onnx.AttributeProto{
diff --git a/ops/gru/versions.go b/ops/gru/versions.go
index e157efa..1da6924 100644
--- a/ops/gru/versions.go
+++ b/ops/gru/versions.go
@@ -3,7 +3,8 @@ package gru
 import "github.com/advancedclimatesystems/gonnx/ops"
 
 var gruVersions = ops.OperatorVersions{
-	7: ops.NewOperatorConstructor(newGRU, 7, gruTypeConstraints),
+	7:  ops.NewOperatorConstructor(newGRU, 7, gruTypeConstraints),
+	14: ops.NewOperatorConstructor(newGRU, 14, gruTypeConstraints),
 }
 
 func GetVersions() ops.OperatorVersions {
diff --git a/ops/lstm/lstm.go b/ops/lstm/lstm.go
index f7bc057..8bc80c8 100644
--- a/ops/lstm/lstm.go
+++ b/ops/lstm/lstm.go
@@ -32,6 +32,7 @@ type LSTM struct {
 	activations []string
 	direction   ops.SequenceProcessDirection
 	hiddenSize  int
+	layout      int
 
 	inputForget bool
 	outputs     []string
@@ -49,6 +50,7 @@ func newLSTM(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
 		),
 		activations: []string{"sigmoid", "tanh", "tanh"},
 		direction:   ops.Forward,
+		layout:      0,
 		inputForget: false,
 		outputs:     []string{"Y", "Y_h", "Y_c"},
 	}
@@ -76,6 +78,13 @@ func (l *LSTM) Init(n *onnx.NodeProto) error {
 			if l.direction != ops.Forward {
 				return ops.ErrUnsupportedAttribute(attr.GetName(), l)
 			}
+		case ops.LayoutAttr:
+			// 'layout' is supported since version 14
+			if l.Version() < 14 {
+				return ops.ErrInvalidAttribute(attr.GetName(), l)
+			}
+
+			l.layout = int(attr.GetI())
 		case ops.HiddenSizeAttr:
 			l.hiddenSize = int(attr.GetI())
 		case "input_forget":
@@ -97,8 +106,27 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 	}
 
 	X := inputs[0]
-	seqLength := X.Shape()[0]
-	batchSize := X.Shape()[1]
+
+	var seqLength int
+
+	var batchSize int
+
+	// The 'layout' parameter handles whether or not the batch dimension comes
+	// first in the tensor. If this is the case, we reshape it here at the
+	// beginning of the operation, and reverse it at the end of the operation.
+	if l.layout == 1 {
+		seqLength = X.Shape()[1]
+		batchSize = X.Shape()[0]
+		inputSize := X.Shape()[2]
+
+		err := X.Reshape(seqLength, batchSize, inputSize)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		seqLength = X.Shape()[0]
+		batchSize = X.Shape()[1]
+	}
 
 	Wi, Wo, Wf, Wc, err := l.getWeights(inputs[1])
 	if err != nil {
@@ -124,12 +152,20 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 
 	Ht := inputs[5]
 	if Ht == nil {
-		Ht = ops.ZeroTensor(1, batchSize, l.hiddenSize)
+		if l.layout == 1 {
+			Ht = ops.ZeroTensor(batchSize, 1, l.hiddenSize)
+		} else {
+			Ht = ops.ZeroTensor(1, batchSize, l.hiddenSize)
+		}
 	}
 
 	Ct := inputs[6]
 	if Ct == nil {
-		Ct = ops.ZeroTensor(1, batchSize, l.hiddenSize)
+		if l.layout == 1 {
+			Ct = ops.ZeroTensor(batchSize, 1, l.hiddenSize)
+		} else {
+			Ct = ops.ZeroTensor(1, batchSize, l.hiddenSize)
+		}
 	}
 
 	var Pi, Po, Pf tensor.Tensor
@@ -142,6 +178,19 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		}
 	}
 
+	// If layout is 1, this means batch size comes as first dimension, and
+	// we reshape the hidden states here to the default layout.
+	if l.layout == 1 {
+		numDirections := Ht.Shape()[1]
+		if err = Ht.Reshape(numDirections, batchSize, l.hiddenSize); err != nil {
+			return nil, err
+		}
+
+		if err = Ct.Reshape(numDirections, batchSize, l.hiddenSize); err != nil {
+			return nil, err
+		}
+	}
+
 	// Reshape the hidden and cell tensor without the bidirectional dimension, as
 	// we do not support bidirectional yet. This is the dimension at
 	// index 0.
@@ -232,16 +281,29 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
-	// Reshape the hidden tensor without the bidirectional dimension, as
-	// we do not support bidirectional RNN yet. This is the dimension at
-	// index 0.
+	// Reshape the output according to the specified layout and re-add the
+	// num_directions dimension.
- if err = Y.Reshape(seqLength, 1, batchSize, l.hiddenSize); err != nil { - return nil, err - } + if l.layout == 1 { + if err = Y.Reshape(batchSize, seqLength, 1, l.hiddenSize); err != nil { + return nil, err + } - if err = Yh.Reshape(1, batchSize, l.hiddenSize); err != nil { - return nil, err - } + if err = Yh.Reshape(batchSize, 1, l.hiddenSize); err != nil { + return nil, err + } - if err = Yc.Reshape(1, batchSize, l.hiddenSize); err != nil { - return nil, err + if err = Yc.Reshape(batchSize, 1, l.hiddenSize); err != nil { + return nil, err + } + } else { + if err = Y.Reshape(seqLength, 1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + if err = Yh.Reshape(1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + if err = Yc.Reshape(1, batchSize, l.hiddenSize); err != nil { + return nil, err + } } outputMap := map[string]tensor.Tensor{ diff --git a/ops/lstm/lstm_test.go b/ops/lstm/lstm_test.go index 4d66882..db4c140 100644 --- a/ops/lstm/lstm_test.go +++ b/ops/lstm/lstm_test.go @@ -48,11 +48,12 @@ func TestLSTMInitUnkownAttr(t *testing.T) { func TestLSTM(t *testing.T) { tests := []struct { - version int64 - attrs *onnx.NodeProto - inputs ops.InputFixture - expected []float32 - err error + version int64 + attrs *onnx.NodeProto + inputs ops.InputFixture + expected []float32 + expectedShape tensor.Shape + err error }{ { 7, @@ -69,6 +70,7 @@ func TestLSTM(t *testing.T) { lstmInput0, []float32{0.9159305, 0.9356764, 0.87070554, 0.84180677}, nil, + nil, }, { 7, @@ -85,6 +87,7 @@ func TestLSTM(t *testing.T) { lstmInput0, []float32{1.7530097, 1.7829735, 1.6231446, 1.5197954}, nil, + nil, }, { 7, @@ -101,6 +104,7 @@ func TestLSTM(t *testing.T) { lstmInput1, []float32{10.598255, 10.547241, 10.214846, 10.267471}, nil, + nil, }, { 7, @@ -117,6 +121,7 @@ func TestLSTM(t *testing.T) { lstmInputNoBNoH, []float32{8.276371, 8.291079, 8.161418, 7.7900877}, nil, + nil, }, { 7, @@ -133,6 +138,45 @@ func TestLSTM(t *testing.T) { lstmInputPeepholes, []float32{0.99891853, 0.99994266, 0.9995524, 0.99171203}, nil, + nil, + }, + { + 14, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "layout", I: 1}, + }, + Output: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInputBatchFirst, + []float32{0.94253653, 0.98116714, 0.9265363, 0.9144332, 0.93192303, 0.97583324, 0.91199535, 0.8959566}, + // shape [batch_size, sequence_length, num_directions, hidden_size] + []int{2, 10, 1, 4}, + nil, + }, + { + 14, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "layout", I: 1}, + }, + Output: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInputBatchFirstNoHNoC, + []float32{0.99986863, 0.9999849, 0.98989785, 0.99847853, 0.9999217, 0.9999911, 0.9903357, 0.998897}, + // shape [batch_size, sequence_length, num_directions, hidden_size] + []int{2, 10, 1, 4}, + nil, }, } @@ -148,6 +192,10 @@ func TestLSTM(t *testing.T) { if err == nil { assert.Equal(t, test.expected, res[1].Data()) + + if test.expectedShape != nil { + assert.Equal(t, test.expectedShape, 
res[0].Shape())
+			}
 		}
 	}
 }
@@ -404,6 +452,52 @@ func lstmInputPeepholes() []tensor.Tensor {
 	}
 }
 
+func lstmInputBatchFirst() []tensor.Tensor {
+	r := rand.New(rand.NewSource(13))
+
+	return []tensor.Tensor{
+		// Input X: (batch_size, sequence_length, input_size).
+		ops.RandomFloat32TensorFixture(r, 2, 10, 3),
+		// Input W: (num_directions, 4 * hidden_size, input_size).
+		ops.RandomFloat32TensorFixture(r, 1, 16, 3),
+		// Input R: (num_directions, 4 * hidden_size, hidden_size).
+		ops.RandomFloat32TensorFixture(r, 1, 16, 4),
+		// Input B.
+		nil,
+		// Input sequence_lens: not supported.
+		nil,
+		// Input initial_h: (batch_size, num_directions, hidden_size).
+		ops.RandomFloat32TensorFixture(r, 2, 1, 4),
+		// Input initial_c: (batch_size, num_directions, hidden_size).
+		ops.RandomFloat32TensorFixture(r, 2, 1, 4),
+		// Input P: peephole weights.
+		nil,
+	}
+}
+
+func lstmInputBatchFirstNoHNoC() []tensor.Tensor {
+	r := rand.New(rand.NewSource(13))
+
+	return []tensor.Tensor{
+		// Input X: (batch_size, sequence_length, input_size).
+		ops.RandomFloat32TensorFixture(r, 2, 10, 3),
+		// Input W: (num_directions, 4 * hidden_size, input_size).
+		ops.RandomFloat32TensorFixture(r, 1, 16, 3),
+		// Input R: (num_directions, 4 * hidden_size, hidden_size).
+		ops.RandomFloat32TensorFixture(r, 1, 16, 4),
+		// Input B.
+		nil,
+		// Input sequence_lens: not supported.
+		nil,
+		// Input initial_h.
+		nil,
+		// Input initial_c.
+		nil,
+		// Input P: (num_directions, 3 * hidden_size).
+		ops.RandomFloat32TensorFixture(r, 1, 12),
+	}
+}
+
 func LSTMOnnxNodeProtoFixture() *onnx.NodeProto {
 	return &onnx.NodeProto{
 		Attribute: []*onnx.AttributeProto{
diff --git a/ops/lstm/versions.go b/ops/lstm/versions.go
index 8872f69..27db895 100644
--- a/ops/lstm/versions.go
+++ b/ops/lstm/versions.go
@@ -3,7 +3,8 @@ package lstm
 import "github.com/advancedclimatesystems/gonnx/ops"
 
 var lstmVersions = ops.OperatorVersions{
-	7: ops.NewOperatorConstructor(newLSTM, 7, lstmTypeConstraints),
+	7:  ops.NewOperatorConstructor(newLSTM, 7, lstmTypeConstraints),
+	14: ops.NewOperatorConstructor(newLSTM, 14, lstmTypeConstraints),
 }
 
 func GetVersions() ops.OperatorVersions {
diff --git a/ops/mul/versions.go b/ops/mul/versions.go
index 2989254..f6fb827 100644
--- a/ops/mul/versions.go
+++ b/ops/mul/versions.go
@@ -1,12 +1,21 @@
 package mul
 
-import "github.com/advancedclimatesystems/gonnx/ops"
+import (
+	"github.com/advancedclimatesystems/gonnx/ops"
+	"gorgonia.org/tensor"
+)
 
 var mulVersions = ops.OperatorVersions{
 	7:  ops.NewOperatorConstructor(newMul, 7, mulTypeConstraints),
 	13: ops.NewOperatorConstructor(newMul, 13, mulTypeConstraints),
+	14: ops.NewOperatorConstructor(newMul, 14, mul14TypeConstraints),
 }
 
 func GetVersions() ops.OperatorVersions {
 	return mulVersions
 }
+
+var mul14TypeConstraints = [][]tensor.Dtype{
+	ops.NumericTypes,
+	ops.NumericTypes,
+}
diff --git a/ops/recurrent_utils.go b/ops/recurrent_utils.go
index 98564f1..8a86424 100644
--- a/ops/recurrent_utils.go
+++ b/ops/recurrent_utils.go
@@ -23,6 +23,7 @@ const (
 	ClipAttr       = "clip"
 	DirectionAttr  = "direction"
 	HiddenSizeAttr = "hidden_size"
+	LayoutAttr     = "layout"
 )
 
 // ExtractMatrices extracts a given number of matrices from tensor M.
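The layout attribute shared by GRU, LSTM and RNN above follows the ONNX opset 14 convention: with the default layout=0, X is shaped [seq_length, batch_size, input_size] and initial_h is [num_directions, batch_size, hidden_size]; with layout=1 the batch dimension moves to the front, so X becomes [batch_size, seq_length, input_size], initial_h becomes [batch_size, num_directions, hidden_size], and Y becomes [batch_size, seq_length, num_directions, hidden_size]. A minimal sketch of the two shape conventions using gorgonia's tensor package, which this repository already depends on (the concrete sizes are illustrative and not taken from the fixtures above):

package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

func main() {
	seqLength, batchSize, inputSize := 5, 2, 3

	// layout=0 (default): X is [seq_length, batch_size, input_size].
	x0 := tensor.New(tensor.WithShape(seqLength, batchSize, inputSize), tensor.Of(tensor.Float32))

	// layout=1 (new in opset 14): X is [batch_size, seq_length, input_size].
	// The operators in this patch bring such a tensor back to the default
	// layout before running the recurrence and restore the requested layout
	// on the outputs afterwards.
	x1 := tensor.New(tensor.WithShape(batchSize, seqLength, inputSize), tensor.Of(tensor.Float32))

	fmt.Println(x0.Shape(), x1.Shape()) // (5, 2, 3) (2, 5, 3)
}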
diff --git a/ops/relu/versions.go b/ops/relu/versions.go
index 89f6f6e..aa95e8e 100644
--- a/ops/relu/versions.go
+++ b/ops/relu/versions.go
@@ -1,12 +1,18 @@
 package relu
 
-import "github.com/advancedclimatesystems/gonnx/ops"
+import (
+	"github.com/advancedclimatesystems/gonnx/ops"
+	"gorgonia.org/tensor"
+)
 
 var reluVersions = ops.OperatorVersions{
 	6:  ops.NewOperatorConstructor(newRelu, 6, reluTypeConstraints),
 	13: ops.NewOperatorConstructor(newRelu, 13, reluTypeConstraints),
+	14: ops.NewOperatorConstructor(newRelu, 14, relu14TypeConstraints),
 }
 
 func GetVersions() ops.OperatorVersions {
 	return reluVersions
 }
+
+var relu14TypeConstraints = [][]tensor.Dtype{{tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}}
diff --git a/ops/rnn/rnn.go b/ops/rnn/rnn.go
index 2649e24..41fbc67 100644
--- a/ops/rnn/rnn.go
+++ b/ops/rnn/rnn.go
@@ -30,6 +30,7 @@ type RNN struct {
 	activations []string
 	direction   ops.SequenceProcessDirection
 	hiddenSize  int
+	layout      int
 }
 
 // newRNN creates a new rnn operator.
@@ -44,6 +45,7 @@ func newRNN(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
 		),
 		activations: []string{"tanh"},
 		direction:   ops.Forward,
+		layout:      0,
 	}
 }
 
@@ -71,6 +73,13 @@ func (r *RNN) Init(n *onnx.NodeProto) error {
 			}
 		case ops.HiddenSizeAttr:
 			r.hiddenSize = int(attr.GetI())
+		case ops.LayoutAttr:
+			// 'layout' is supported since version 14
+			if r.Version() < 14 {
+				return ops.ErrInvalidAttribute(attr.GetName(), r)
+			}
+
+			r.layout = int(attr.GetI())
 		default:
 			return ops.ErrInvalidAttribute(attr.GetName(), r)
 		}
@@ -86,8 +95,27 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 	}
 
 	X := inputs[0]
-	seqLength := X.Shape()[0]
-	batchSize := X.Shape()[1]
+
+	var seqLength int
+
+	var batchSize int
+
+	// The 'layout' parameter handles whether or not the batch dimension comes
+	// first in the tensor. If this is the case, we reshape it here at the
+	// beginning of the operation, and reverse it at the end of the operation.
+	if r.layout == 1 {
+		seqLength = X.Shape()[1]
+		batchSize = X.Shape()[0]
+		inputSize := X.Shape()[2]
+
+		err := X.Reshape(seqLength, batchSize, inputSize)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		seqLength = X.Shape()[0]
+		batchSize = X.Shape()[1]
+	}
 
 	Wi, err := r.getWeights(inputs[1])
 	if err != nil {
@@ -113,7 +141,20 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 
 	Ht := inputs[5]
 	if Ht == nil {
-		Ht = ops.ZeroTensor(1, batchSize, r.hiddenSize)
+		if r.layout == 1 {
+			Ht = ops.ZeroTensor(batchSize, 1, r.hiddenSize)
+		} else {
+			Ht = ops.ZeroTensor(1, batchSize, r.hiddenSize)
+		}
+	}
+
+	// If layout is 1, this means batch size comes as first dimension, and
+	// we reshape it here to the default layout.
+	if r.layout == 1 {
+		numDirections := Ht.Shape()[1]
+		if err = Ht.Reshape(numDirections, batchSize, r.hiddenSize); err != nil {
+			return nil, err
+		}
 	}
 
 	// Reshape the hidden tensor without the bidirectional dimension, as
@@ -159,15 +200,24 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		return nil, ops.ErrTypeAssert("tensor.Tensor", Ht.Clone())
 	}
 
-	// Reshape the hidden tensor without the bidirectional dimension, as
-	// we do not support bidirectional RNN yet. This is the dimension at
-	// index 0.
-	if err = Y.Reshape(seqLength, 1, batchSize, r.hiddenSize); err != nil {
-		return nil, err
-	}
+	// Reshape the output according to the specified layout and re-add the
+	// num_directions dimension.
+ if r.layout == 1 { + if err = Y.Reshape(batchSize, seqLength, 1, r.hiddenSize); err != nil { + return nil, err + } - if err = Yh.Reshape(1, batchSize, r.hiddenSize); err != nil { - return nil, err + if err = Yh.Reshape(batchSize, 1, r.hiddenSize); err != nil { + return nil, err + } + } else { + if err = Y.Reshape(seqLength, 1, batchSize, r.hiddenSize); err != nil { + return nil, err + } + + if err = Yh.Reshape(1, batchSize, r.hiddenSize); err != nil { + return nil, err + } } return []tensor.Tensor{Y, Yh}, nil diff --git a/ops/rnn/rnn_test.go b/ops/rnn/rnn_test.go index fb874ff..d442cbd 100644 --- a/ops/rnn/rnn_test.go +++ b/ops/rnn/rnn_test.go @@ -36,11 +36,12 @@ func TestRNNInitUnknownAttr(t *testing.T) { func TestRNN(t *testing.T) { tests := []struct { - version int64 - attrs *onnx.NodeProto - inputs ops.InputFixture - expected []float32 - err error + version int64 + attrs *onnx.NodeProto + inputs ops.InputFixture + expected []float32 + expectedShape tensor.Shape + err error }{ { 7, @@ -56,6 +57,7 @@ func TestRNN(t *testing.T) { rnnInput0, []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, nil, + nil, }, { 7, @@ -71,6 +73,7 @@ func TestRNN(t *testing.T) { rnnInput0, []float32{0.82048327, 0.922734, 0.89050114, 0.8620579}, nil, + nil, }, { 7, @@ -87,6 +90,7 @@ func TestRNN(t *testing.T) { rnnInput0, []float32{1.0667435, 2.328037, 1.7986122, 1.545068}, nil, + nil, }, { 7, @@ -102,6 +106,7 @@ func TestRNN(t *testing.T) { rnnInput1, []float32{0.99996024, 0.9999855, 0.99998087, 0.9999288, 0.9997511, 0.99918234, 0.99999964, 0.9999981, 0.9997658, 0.9999618, 0.9998762, 0.9999353, 0.9999194, 0.9999428, 0.9997284, 0.9982606, 0.999999, 0.9999897, 0.99964744, 0.9998234, 0.99997497, 0.9999893, 0.9999906, 0.9999812, 0.99983937, 0.99967873, 0.9999998, 0.9999965, 0.9999516, 0.9999541}, nil, + nil, }, { 7, @@ -118,6 +123,7 @@ func TestRNN(t *testing.T) { // Same values as first test, but B is initialized automatically. []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, nil, + nil, }, { 7, @@ -134,6 +140,43 @@ func TestRNN(t *testing.T) { // Same values as first test, but B and H are initialized automatically. 
[]float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, nil, + nil, + }, + { + 14, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + {Name: "layout", I: 1}, + }, + }, + rnnInputBatchFirst, + []float32{0.98516846, 0.9842066, 0.99648196, 0.999319, 0.95900106, 0.96534646, 0.9786028, 0.99493814, 0.9979183, 0.9448906}, + // shape [batch_size, sequence_length, num_directions, hidden_size] + []int{2, 4, 1, 5}, + nil, + }, + { + 14, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + {Name: "layout", I: 1}, + }, + }, + rnnInputBatchFirstNoH, + []float32{0.98516846, 0.9842066, 0.99648196, 0.999319, 0.95900106, 0.96534646, 0.9786028, 0.99493814, 0.9979183, 0.9448906}, + // shape [batch_size, sequence_length, num_directions, hidden_size] + []int{2, 4, 1, 5}, + nil, }, } @@ -149,6 +192,10 @@ func TestRNN(t *testing.T) { if err == nil { assert.Equal(t, test.expected, res[1].Data()) + + if test.expectedShape != nil { + assert.Equal(t, test.expectedShape, res[0].Shape()) + } } } } @@ -356,6 +403,44 @@ func rnnInputNoBNoH() []tensor.Tensor { } } +func rnnInputBatchFirst() []tensor.Tensor { + r := rand.New(rand.NewSource(13)) + + return []tensor.Tensor{ + // Input X: (batch_size, sequence_length, input_size). + ops.RandomFloat32TensorFixture(r, 2, 4, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(r, 1, 5, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(r, 1, 5, 5), + // Input B: not provided. + nil, + // Input sequence_lens: not supported + nil, + // Input initial_h: (batch_size, num_directions, hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(2, 1, 5)), 2, 1, 5), + } +} + +func rnnInputBatchFirstNoH() []tensor.Tensor { + r := rand.New(rand.NewSource(13)) + + return []tensor.Tensor{ + // Input X: (batch_size, sequence_length, input_size). + ops.RandomFloat32TensorFixture(r, 2, 4, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(r, 1, 5, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(r, 1, 5, 5), + // Input B: not provided. 
+ nil, + // Input sequence_lens: not supported + nil, + // Input initial_h: (batch_size, num_directions, hidden_size) + nil, + } +} + func RNNOnnxNodeProtoFixture() *onnx.NodeProto { return &onnx.NodeProto{ Attribute: []*onnx.AttributeProto{ diff --git a/ops/rnn/versions.go b/ops/rnn/versions.go index dcea592..6733aa3 100644 --- a/ops/rnn/versions.go +++ b/ops/rnn/versions.go @@ -3,7 +3,8 @@ package rnn import "github.com/advancedclimatesystems/gonnx/ops" var rnnVersions = ops.OperatorVersions{ - 7: ops.NewOperatorConstructor(newRNN, 7, rnnTypeConstraints), + 7: ops.NewOperatorConstructor(newRNN, 7, rnnTypeConstraints), + 14: ops.NewOperatorConstructor(newRNN, 14, rnnTypeConstraints), } func GetVersions() ops.OperatorVersions { diff --git a/ops/sub/versions.go b/ops/sub/versions.go index 64e5db8..311abdf 100644 --- a/ops/sub/versions.go +++ b/ops/sub/versions.go @@ -1,12 +1,21 @@ package sub -import "github.com/advancedclimatesystems/gonnx/ops" +import ( + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) var subVersions = ops.OperatorVersions{ 7: ops.NewOperatorConstructor(newSub, 7, subTypeConstraints), 13: ops.NewOperatorConstructor(newSub, 13, subTypeConstraints), + 14: ops.NewOperatorConstructor(newSub, 14, sub14TypeConstraints), } func GetVersions() ops.OperatorVersions { return subVersions } + +var sub14TypeConstraints = [][]tensor.Dtype{ + ops.NumericTypes, + ops.NumericTypes, +} diff --git a/ops_test.go b/ops_test.go index d065549..f238c77 100644 --- a/ops_test.go +++ b/ops_test.go @@ -23,9 +23,6 @@ import ( // Another reason is that some tests require an opset version higher than we have currently // implemented, or lower, which we also haven't implemented yet. var ignoredTests = []string{ - "test_add_uint8", // Opset14 - "test_div_uint8", // Opset14 - "test_gru_batchwise", // Opset14 "test_logsoftmax_axis_1_expanded_ver18", // Opset18 "test_logsoftmax_example_1_expanded_ver18", // Opset18 "test_logsoftmax_negative_axis_expanded_ver18", // Opset18 @@ -33,8 +30,6 @@ var ignoredTests = []string{ "test_logsoftmax_default_axis_expanded_ver18", // Opset18 "test_logsoftmax_axis_0_expanded_ver18", // Opset18 "test_logsoftmax_axis_2_expanded_ver18", // Opset18 - "test_lstm_batchwise", // Opset14 - "test_mul_uint8", // Opset14 "test_reduce_max_empty_set", // Opset20 "test_reduce_max_do_not_keepdims_random", // Opset18 "test_reduce_max_keepdims_random", // Opset18 @@ -63,7 +58,6 @@ var ignoredTests = []string{ "test_reduce_mean_default_axes_keepdims_example", // Opset18 "test_reduce_mean_do_not_keepdims_example", // Opset18 "test_reduce_mean_keepdims_example", // Opset18 - "test_sub_uint8", // Opset14 "test_shape_clip_end", // Opset15 "test_shape_clip_start", // Opset15 "test_shape_end_1", // Opset15 @@ -364,6 +358,7 @@ var expectedTests = []string{ "test_acosh_example", "test_add", "test_add_bcast", + "test_add_uint8", "test_and_bcast3v1d", "test_and_bcast3v2d", "test_and_bcast4v2d", @@ -420,6 +415,7 @@ var expectedTests = []string{ "test_div", "test_div_bcast", "test_div_example", + "test_div_uint8", "test_equal", "test_equal_bcast", "test_erf", @@ -455,6 +451,7 @@ var expectedTests = []string{ "test_greater_equal_bcast", "test_greater_equal_bcast_expanded", "test_greater_equal_expanded", + "test_gru_batchwise", "test_gru_defaults", "test_gru_seq_length", "test_gru_with_initial_bias", @@ -471,6 +468,7 @@ var expectedTests = []string{ "test_logsoftmax_example_1", "test_logsoftmax_large_number", "test_logsoftmax_negative_axis", + "test_lstm_batchwise", 
"test_lstm_defaults", "test_lstm_with_initial_bias", "test_matmul_4d", @@ -479,6 +477,7 @@ var expectedTests = []string{ "test_mul", "test_mul_bcast", "test_mul_example", + "test_mul_uint8", "test_not_2d", "test_not_3d", "test_not_4d", @@ -537,6 +536,7 @@ var expectedTests = []string{ "test_sub", "test_sub_bcast", "test_sub_example", + "test_sub_uint8", "test_tan", "test_tan_example", "test_tanh", diff --git a/opset.go b/opset.go index f1f67ad..d62987e 100644 --- a/opset.go +++ b/opset.go @@ -67,7 +67,7 @@ import ( const ( MinSupportedOpset = 7 - MaxSupportedOpset = 13 + MaxSupportedOpset = 14 ) // Opset is a set of operators matching a certain opset version. @@ -134,7 +134,7 @@ var operators = map[string]ops.OperatorVersions{ "Transpose": transpose.GetVersions(), "Unsqueeze": unsqueeze.GetVersions(), "Xor": xor.GetVersions(), - "Where": where.GetVersions(), + "Where": where.GetVersions(), } // GetClosestOperatorVersion resolves, given a certain opset version, the operator version that is closest From 62b0c92a33a90e20459d65d761058acad25f19cf Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 30 Mar 2025 17:16:13 +0200 Subject: [PATCH 2/4] Added Trilu operator --- ops/trilu/trilu.go | 107 ++++++++++++++++++++++++++ ops/trilu/trilu_test.go | 161 ++++++++++++++++++++++++++++++++++++++++ ops/trilu/versions.go | 13 ++++ ops/utils.go | 2 +- ops_test.go | 52 ++++++++++--- opset.go | 2 + test.py | 17 +++++ 7 files changed, 341 insertions(+), 13 deletions(-) create mode 100644 ops/trilu/trilu.go create mode 100644 ops/trilu/trilu_test.go create mode 100644 ops/trilu/versions.go create mode 100644 test.py diff --git a/ops/trilu/trilu.go b/ops/trilu/trilu.go new file mode 100644 index 0000000..ee54468 --- /dev/null +++ b/ops/trilu/trilu.go @@ -0,0 +1,107 @@ +package trilu + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var triluTypeConstraints = [][]tensor.Dtype{ops.AllTypes, []tensor.Dtype{tensor.Int64}} + +// Trilu represents the ONNX trilu operator. +type Trilu struct { + ops.BaseOperator + + upper bool +} + +// newTrilu creates a new trilu operator. +func newTrilu(version int, typeConstraint [][]tensor.Dtype) ops.Operator { + return &Trilu{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 2, + typeConstraint, + "trilu", + ), + upper: true, // Default is true as per ONNX spec + } +} + +// Init initializes the trilu operator. +func (t *Trilu) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "upper": + t.upper = attr.GetI() == 1 + default: + return ops.ErrInvalidAttribute(attr.GetName(), t) + } + } + + return nil +} + +// Apply applies the trilu operator. 
+func (t *Trilu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
+	data := inputs[0]
+
+	var k int
+
+	if inputs[1] != nil {
+		var err error
+
+		k, err = ops.AnyToInt(inputs[1].ScalarValue())
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	rank := len(data.Shape())
+	if rank < 2 {
+		return nil, ops.ErrInvalidInput("input tensor must be at least rank 2", t.BaseOperator)
+	}
+
+	// Create output tensor with same shape and type as input
+	out, ok := data.Clone().(tensor.Tensor)
+	if !ok {
+		return nil, ops.ErrTypeAssert("tensor.Tensor", data.Clone())
+	}
+
+	zeroVal, err := ops.GetValueAsTensorType(0.0, out.Dtype())
+	if err != nil {
+		return nil, err
+	}
+
+	it := out.Iterator()
+	it.Reset()
+
+	for !it.Done() {
+		coords := it.Coord()
+
+		row := coords[rank-2]
+		col := coords[rank-1]
+
+		shouldZero := false
+		if t.upper {
+			shouldZero = col-row < k
+		} else {
+			shouldZero = col-row > k
+		}
+
+		if shouldZero {
+			err = out.SetAt(zeroVal, coords...)
+			if err != nil {
+				return nil, err
+			}
+		}
+
+		_, err := it.Next()
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return []tensor.Tensor{out}, nil
+}
diff --git a/ops/trilu/trilu_test.go b/ops/trilu/trilu_test.go
new file mode 100644
index 0000000..0cbae63
--- /dev/null
+++ b/ops/trilu/trilu_test.go
@@ -0,0 +1,161 @@
+package trilu
+
+import (
+	"testing"
+
+	"github.com/advancedclimatesystems/gonnx/onnx"
+	"github.com/advancedclimatesystems/gonnx/ops"
+	"github.com/stretchr/testify/assert"
+	"gorgonia.org/tensor"
+)
+
+func TestTriluInit(t *testing.T) {
+	// Test with upper = 1 (default, upper triangular)
+	attrs := makeUpperAttrProto(1)
+	op := Trilu{}
+	err := op.Init(attrs)
+	assert.NoError(t, err)
+	assert.Equal(t, true, op.upper)
+
+	// Test with upper = 0 (lower triangular)
+	attrs = makeUpperAttrProto(0)
+	op = Trilu{}
+	err = op.Init(attrs)
+	assert.NoError(t, err)
+	assert.Equal(t, false, op.upper)
+}
+
+func TestTriluInitDefault(t *testing.T) {
+	op, ok := newTrilu(14, triluTypeConstraints).(*Trilu)
+	assert.True(t, ok)
+
+	err := op.Init(ops.EmptyNodeProto())
+	assert.Nil(t, err)
+	assert.Equal(t, true, op.upper) // Default is true
+}
+
+func TestTriluInitInvalidAttrName(t *testing.T) {
+	op := Trilu{BaseOperator: ops.NewBaseOperator(14, 1, 2, triluTypeConstraints, "trilu")}
+	err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "invalid"}}})
+	assert.EqualError(t, err, "trilu v14 attribute error: invalid attribute invalid")
+}
+
+func TestTrilu(t *testing.T) {
+	tests := []struct {
+		version   int64
+		attrs     *onnx.NodeProto
+		data      interface{}
+		dataShape []int
+		k         int64
+		expected  interface{}
+	}{
+		{
+			14,
+			&onnx.NodeProto{
+				Attribute: []*onnx.AttributeProto{
+					{Name: "upper", I: 1},
+				},
+			},
+			[]int64{0, 1, 2, 3, 4, 5, 6, 7, 8},
+			[]int{3, 3},
+			0,
+			[]int64{0, 1, 2, 0, 4, 5, 0, 0, 8},
+		},
+		{
+			14,
+			&onnx.NodeProto{
+				Attribute: []*onnx.AttributeProto{
+					{Name: "upper", I: 0},
+				},
+			},
+			[]int64{0, 1, 2, 3, 4, 5, 6, 7, 8},
+			[]int{3, 3},
+			0,
+			[]int64{0, 0, 0, 3, 4, 0, 6, 7, 8},
+		},
+		{
+			14,
+			&onnx.NodeProto{
+				Attribute: []*onnx.AttributeProto{
+					{Name: "upper", I: 1},
+				},
+			},
+			[]int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
+			[]int{3, 4},
+			0,
+			[]int64{0, 1, 2, 3, 0, 5, 6, 7, 0, 0, 10, 11},
+		},
+	}
+
+	for _, test := range tests {
+		op := triluVersions[test.version]()
+		err := op.Init(test.attrs)
+		assert.Nil(t, err)
+
+		in := ops.TensorWithBackingFixture(test.data, test.dataShape...)
+
+		k := tensor.New(tensor.FromScalar(test.k))
+
+		res, err := op.Apply([]tensor.Tensor{in, k})
+		assert.Nil(t, err)
+
+		if err == nil {
+			assert.Equal(t, test.expected, res[0].Data())
+		}
+	}
+}
+
+func TestInputValidationTrilu(t *testing.T) {
+	tests := []struct {
+		version int64
+		inputs  []tensor.Tensor
+		err     error
+	}{
+		{
+			14,
+			[]tensor.Tensor{
+				ops.TensorWithBackingFixture([]uint32{1, 2}, 2),
+				tensor.New(tensor.FromScalar(int64(0))),
+			},
+			nil,
+		},
+		{
+			14,
+			[]tensor.Tensor{
+				ops.TensorWithBackingFixture([]int64{1, 2}, 2),
+				ops.TensorWithBackingFixture([]int64{1, 2}, 2),
+				ops.TensorWithBackingFixture([]int64{1, 2}, 2),
+			},
+			ops.ErrInvalidOptionalInputCount(3, trilu14BaseOpFixture()),
+		},
+		{
+			14,
+			[]tensor.Tensor{
+				ops.TensorWithBackingFixture([]int{1, 2}, 2),
+				tensor.New(tensor.FromScalar(0)),
+			},
+			ops.ErrInvalidInputType(0, "int", trilu14BaseOpFixture()),
+		},
+	}
+
+	for _, test := range tests {
+		trilu := triluVersions[test.version]()
+		validated, err := trilu.ValidateInputs(test.inputs)
+
+		assert.Equal(t, test.err, err)
+
+		if test.err == nil {
+			assert.Equal(t, test.inputs, validated)
+		}
+	}
+}
+
+func trilu14BaseOpFixture() ops.BaseOperator {
+	return ops.NewBaseOperator(14, 1, 2, triluTypeConstraints, "trilu")
+}
+
+func makeUpperAttrProto(upper int64) *onnx.NodeProto {
+	return &onnx.NodeProto{
+		Attribute: []*onnx.AttributeProto{{Name: "upper", I: upper}},
+	}
+}
diff --git a/ops/trilu/versions.go b/ops/trilu/versions.go
new file mode 100644
index 0000000..3d894b7
--- /dev/null
+++ b/ops/trilu/versions.go
@@ -0,0 +1,13 @@
+package trilu
+
+import (
+	"github.com/advancedclimatesystems/gonnx/ops"
+)
+
+var triluVersions = ops.OperatorVersions{
+	14: ops.NewOperatorConstructor(newTrilu, 14, triluTypeConstraints),
+}
+
+func GetVersions() ops.OperatorVersions {
+	return triluVersions
+}
diff --git a/ops/utils.go b/ops/utils.go
index f41d911..895cb1b 100644
--- a/ops/utils.go
+++ b/ops/utils.go
@@ -228,7 +228,7 @@ func NElements(shp ...int) int {
 	return nElem
 }
 
-// PairwiseAssign essentially does pairwise t1 = t2 in place!.
+// PairwiseAssign essentially does pairwise t1 = t2 in place.
 func PairwiseAssign(t1, t2 tensor.Tensor) (err error) {
 	if !t1.Shape().Eq(t2.Shape()) {
 		return ErrInvalidShape
diff --git a/ops_test.go b/ops_test.go
index f238c77..1db098f 100644
--- a/ops_test.go
+++ b/ops_test.go
@@ -99,6 +99,9 @@ var ignoredTests = []string{
 	"test_slice_neg_steps", // ONNX expects nil output, but we throw an error.
 	"test_slice_neg",       // ONNX expects nil output, but we throw an error.
 
+	"test_tril_zero", // Has a zero dimension in the middle (3, 0, 5) which is not supported in tensor.Tensor.
+	"test_triu_zero", // Has a zero dimension in the middle (3, 0, 5) which is not supported in tensor.Tensor.
+
 	"test_equal_string",           // Unsupported datatype String.
 	"test_equal_string_broadcast", // Unsupported datatype String.
 	"test_cast_INT4_to_INT8",      // Unsupported datatype INT4.
@@ -206,13 +209,19 @@ func TestOps(t *testing.T) {
 
 func getTestCasesForOp(opName string) ([]*ONNXTestCase, error) {
 	testOpName := strings.ToLower(opName)
+
+	filterNames := []string{testOpName}
+
 	// Because the naming of the ONNX test cases are not fully consistent, we need
 	// to map some operator names to insert some '_' in the filter.
 	if mappedFilter, ok := opNameMap[testOpName]; ok {
-		testOpName = mappedFilter
+		filterNames = append(filterNames, mappedFilter...)
} - opFilter := fmt.Sprintf("test_%v", testOpName) + opFilters := make([]string, len(filterNames)) + for i, filterName := range filterNames { + opFilters[i] = fmt.Sprintf("test_%v", filterName) + } testDir, err := os.Open("./test_data") if err != nil { @@ -227,7 +236,7 @@ func getTestCasesForOp(opName string) ([]*ONNXTestCase, error) { var tests []*ONNXTestCase for _, testFolder := range testFolders { - if shouldRunTest(testFolder, opFilter) { + if shouldRunTest(testFolder, opFilters) { testcase, err := getTestCase(fmt.Sprintf("./test_data/%v", testFolder)) if err != nil { return nil, err @@ -241,17 +250,19 @@ func getTestCasesForOp(opName string) ([]*ONNXTestCase, error) { return tests, nil } -func shouldRunTest(folder, opFilter string) bool { +func shouldRunTest(folder string, opFilters []string) bool { for _, ignoredTest := range ignoredTests { if folder == ignoredTest { return false } } - if strings.Contains(folder, opFilter) { - remaining := strings.ReplaceAll(folder, opFilter, "") - if len(remaining) == 0 || remaining[:1] == "_" { - return true + for _, opFilter := range opFilters { + if strings.Contains(folder, opFilter) { + remaining := strings.ReplaceAll(folder, opFilter, "") + if len(remaining) == 0 || remaining[:1] == "_" { + return true + } } } @@ -548,6 +559,22 @@ var expectedTests = []string{ "test_transpose_all_permutations_4", "test_transpose_all_permutations_5", "test_transpose_default", + "test_tril", + "test_tril_neg", + "test_tril_one_row_neg", + "test_tril_out_neg", + "test_tril_out_pos", + "test_tril_pos", + "test_tril_square", + "test_tril_square_neg", + "test_triu", + "test_triu_neg", + "test_triu_one_row", + "test_triu_out_neg_out", + "test_triu_out_pos", + "test_triu_pos", + "test_triu_square", + "test_triu_square_neg", "test_unsqueeze_axis_0", "test_unsqueeze_axis_1", "test_unsqueeze_axis_2", @@ -564,8 +591,9 @@ var expectedTests = []string{ "test_xor_bcast4v4d", } -var opNameMap = map[string]string{ - "reducemax": "reduce_max", - "reducemin": "reduce_min", - "reducemean": "reduce_mean", +var opNameMap = map[string][]string{ + "reducemax": []string{"reduce_max"}, + "reducemin": []string{"reduce_min"}, + "reducemean": []string{"reduce_mean"}, + "trilu": []string{"tril", "triu"}, } diff --git a/opset.go b/opset.go index d62987e..16deb4c 100644 --- a/opset.go +++ b/opset.go @@ -60,6 +60,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/tan" "github.com/advancedclimatesystems/gonnx/ops/tanh" "github.com/advancedclimatesystems/gonnx/ops/transpose" + "github.com/advancedclimatesystems/gonnx/ops/trilu" "github.com/advancedclimatesystems/gonnx/ops/unsqueeze" "github.com/advancedclimatesystems/gonnx/ops/where" "github.com/advancedclimatesystems/gonnx/ops/xor" @@ -132,6 +133,7 @@ var operators = map[string]ops.OperatorVersions{ "Tan": tan.GetVersions(), "Tanh": tanh.GetVersions(), "Transpose": transpose.GetVersions(), + "Trilu": trilu.GetVersions(), "Unsqueeze": unsqueeze.GetVersions(), "Xor": xor.GetVersions(), "Where": where.GetVersions(), diff --git a/test.py b/test.py new file mode 100644 index 0000000..2dacf02 --- /dev/null +++ b/test.py @@ -0,0 +1,17 @@ +import onnx +from onnx import numpy_helper + +def load_tensor_proto(pb_path): + with open(pb_path, 'rb') as f: + tensor_proto = onnx.TensorProto() + tensor_proto.ParseFromString(f.read()) + tensor_np = numpy_helper.to_array(tensor_proto) + return tensor_np + +path = "./test_data/test_tril_zero/test_data_set_0/" +# Load your input tensor(s) +input_tensor_1 = load_tensor_proto(path + "input_0.pb") 
+input_tensor_2 = load_tensor_proto(path + "input_1.pb")
+
+print(input_tensor_1.shape, input_tensor_2.shape)
+import pdb;pdb.set_trace()

From bd085bc008f19d27dd3288ad5f9c8dc74f6b5357 Mon Sep 17 00:00:00 2001
From: Swopper050
Date: Sun, 30 Mar 2025 17:18:49 +0200
Subject: [PATCH 3/4] Add Identity opset 14 version

---
 README.md                | 2 +-
 ops/identity/versions.go | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 9d934b6..2241f5c 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ The goal of this package is to provide an easy way of running ONNX models in Go
 is intended for inference usage of ONNX models. The package can be used to load an `.onnx` file
 and perform inference using the model described by this file.
 
-Currently, we are implementing ONNX operation set 13, and we plan to add all opsets following this
+Currently, ONNX opset versions 7 through 14 are supported, and we plan to add all opsets following this
 one as well. Feel free to contribute by implementing operators!
 
 ## Getting started
diff --git a/ops/identity/versions.go b/ops/identity/versions.go
index 96ad876..94daedf 100644
--- a/ops/identity/versions.go
+++ b/ops/identity/versions.go
@@ -6,6 +6,7 @@ import (
 
 var identityVersions = ops.OperatorVersions{
 	13: ops.NewOperatorConstructor(newIdentity, 13, identityTypeConstraints),
+	14: ops.NewOperatorConstructor(newIdentity, 14, identityTypeConstraints),
 }
 
 func GetVersions() ops.OperatorVersions {

From 29eec646e5a48b3a9651d5597c52ae92b940e453 Mon Sep 17 00:00:00 2001
From: Swopper050
Date: Sun, 30 Mar 2025 17:49:34 +0200
Subject: [PATCH 4/4] Clone before reshaping

---
 ops/gru/gru.go         | 36 ++++++++++++------------------
 ops/lstm/lstm.go       | 48 ++++++++++++++++++++----------------------
 ops/recurrent_utils.go | 26 ++++++++++++++++++++++++
 ops/rnn/rnn.go         | 36 ++++++++++++------------------
 4 files changed, 75 insertions(+), 71 deletions(-)

diff --git a/ops/gru/gru.go b/ops/gru/gru.go
index f73d7d1..ae58538 100644
--- a/ops/gru/gru.go
+++ b/ops/gru/gru.go
@@ -102,27 +102,9 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		return nil, ops.ErrUnsupportedInput("sequence lens", g.BaseOperator)
 	}
 
-	X := inputs[0]
-
-	var seqLength int
-
-	var batchSize int
-
-	// The 'layout' parameter handles whether or not the batch dimension comes
-	// first in the tensor. If this is the case, we reshape it here at the
-	// beginning of the operation, and reverse it at the end of the operation.
-	if g.layout == 1 {
-		seqLength = X.Shape()[1]
-		batchSize = X.Shape()[0]
-		inputSize := X.Shape()[2]
-
-		err := X.Reshape(seqLength, batchSize, inputSize)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		seqLength = X.Shape()[0]
-		batchSize = X.Shape()[1]
+	X, seqLength, batchSize, err := ops.ReshapeInputTensorBasedOnLayout(inputs[0], g.layout)
+	if err != nil {
+		return nil, err
 	}
 
 	Wz, Wr, Wh, err := g.getWeights(inputs[1])
@@ -149,13 +131,21 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		return nil, err
 	}
 
-	prevH := inputs[5]
-	if prevH == nil {
+	var prevH tensor.Tensor
+
+	if inputs[5] == nil {
 		if g.layout == 1 {
 			prevH = ops.ZeroTensor(batchSize, 1, g.hiddenSize)
 		} else {
 			prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize)
 		}
+	} else {
+		var ok bool
+
+		prevH, ok = inputs[5].Clone().(tensor.Tensor)
+		if !ok {
+			return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[5].Clone())
+		}
 	}
 
 	// If layout is 1, this means batch size comes as first dimension, and
diff --git a/ops/lstm/lstm.go b/ops/lstm/lstm.go
index 8bc80c8..f9867ce 100644
--- a/ops/lstm/lstm.go
+++ b/ops/lstm/lstm.go
@@ -105,27 +105,9 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		return nil, ops.ErrUnsupportedInput("sequence_lens", l.BaseOperator)
 	}
 
-	X := inputs[0]
-
-	var seqLength int
-
-	var batchSize int
-
-	// The 'layout' parameter handles whether or not the batch dimension comes
-	// first in the tensor. If this is the case, we reshape it here at the
-	// beginning of the operation, and reverse it at the end of the operation.
-	if l.layout == 1 {
-		seqLength = X.Shape()[1]
-		batchSize = X.Shape()[0]
-		inputSize := X.Shape()[2]
-
-		err := X.Reshape(seqLength, batchSize, inputSize)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		seqLength = X.Shape()[0]
-		batchSize = X.Shape()[1]
+	X, seqLength, batchSize, err := ops.ReshapeInputTensorBasedOnLayout(inputs[0], l.layout)
+	if err != nil {
+		return nil, err
 	}
 
 	Wi, Wo, Wf, Wc, err := l.getWeights(inputs[1])
@@ -150,22 +132,38 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		return nil, err
 	}
 
-	Ht := inputs[5]
-	if Ht == nil {
+	var Ht tensor.Tensor
+
+	if inputs[5] == nil {
 		if l.layout == 1 {
 			Ht = ops.ZeroTensor(batchSize, 1, l.hiddenSize)
 		} else {
 			Ht = ops.ZeroTensor(1, batchSize, l.hiddenSize)
 		}
+	} else {
+		var ok bool
+
+		Ht, ok = inputs[5].Clone().(tensor.Tensor)
+		if !ok {
+			return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[5].Clone())
+		}
 	}
 
-	Ct := inputs[6]
-	if Ct == nil {
+	var Ct tensor.Tensor
+
+	if inputs[6] == nil {
 		if l.layout == 1 {
 			Ct = ops.ZeroTensor(batchSize, 1, l.hiddenSize)
 		} else {
 			Ct = ops.ZeroTensor(1, batchSize, l.hiddenSize)
 		}
+	} else {
+		var ok bool
+
+		Ct, ok = inputs[6].Clone().(tensor.Tensor)
+		if !ok {
+			return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[6].Clone())
+		}
 	}
 
 	var Pi, Po, Pf tensor.Tensor
diff --git a/ops/recurrent_utils.go b/ops/recurrent_utils.go
index 8a86424..7414cec 100644
--- a/ops/recurrent_utils.go
+++ b/ops/recurrent_utils.go
@@ -72,3 +72,29 @@ func OnesTensor(t tensor.Tensor) tensor.Tensor {
 		tensor.WithBacking(Ones(NElements(t.Shape()...))),
 	)
 }
+
+// ReshapeInputTensorBasedOnLayout returns the input tensor in the default
+// [seq_length, batch_size, input_size] layout, along with the sequence
+// length and the batch size. When layout is 1, the input is cloned before
+// reshaping so the caller's tensor is left untouched.
+func ReshapeInputTensorBasedOnLayout(X tensor.Tensor, layout int) (tensor.Tensor, int, int, error) {
+	if layout == 1 {
+		newX, ok := X.Clone().(tensor.Tensor)
+		if !ok {
+			return nil, 0, 0, ErrTypeAssert("tensor.Tensor", X.Clone())
+		}
+
+		seqLength := X.Shape()[1]
+		batchSize := X.Shape()[0]
+		inputSize := X.Shape()[2]
+
+		err := newX.Reshape(seqLength, batchSize, inputSize)
+		if err != nil {
+			return nil, 0, 0, err
+		}
+
+		return newX, seqLength, batchSize, nil
+	}
+
+	return X, X.Shape()[0], X.Shape()[1], nil
+}
diff --git a/ops/rnn/rnn.go b/ops/rnn/rnn.go
index 41fbc67..74708b1 100644
--- a/ops/rnn/rnn.go
+++ b/ops/rnn/rnn.go
@@ -94,27 +94,9 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		return nil, ops.ErrUnsupportedInput("sequence lens", r.BaseOperator)
 	}
 
-	X := inputs[0]
-
-	var seqLength int
-
-	var batchSize int
-
-	// The 'layout' parameter handles whether or not the batch dimension comes
-	// first in the tensor. If this is the case, we reshape it here at the
-	// beginning of the operation, and reverse it at the end of the operation.
-	if r.layout == 1 {
-		seqLength = X.Shape()[1]
-		batchSize = X.Shape()[0]
-		inputSize := X.Shape()[2]
-
-		err := X.Reshape(seqLength, batchSize, inputSize)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		seqLength = X.Shape()[0]
-		batchSize = X.Shape()[1]
+	X, seqLength, batchSize, err := ops.ReshapeInputTensorBasedOnLayout(inputs[0], r.layout)
+	if err != nil {
+		return nil, err
 	}
 
 	Wi, err := r.getWeights(inputs[1])
@@ -139,13 +121,21 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
 		return nil, err
 	}
 
-	Ht := inputs[5]
-	if Ht == nil {
+	var Ht tensor.Tensor
+
+	if inputs[5] == nil {
 		if r.layout == 1 {
 			Ht = ops.ZeroTensor(batchSize, 1, r.hiddenSize)
 		} else {
 			Ht = ops.ZeroTensor(1, batchSize, r.hiddenSize)
 		}
+	} else {
+		var ok bool
+
+		Ht, ok = inputs[5].Clone().(tensor.Tensor)
+		if !ok {
+			return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[5].Clone())
+		}
 	}
 
 	// If layout is 1, this means batch size comes as first dimension, and