forked from gorgonia/tensor
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathexample_extension_test.go
101 lines (87 loc) · 2.2 KB
/
example_extension_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package tensor_test
import (
//"errors"
"fmt"
"reflect"
"github.com/pkg/errors"
"github.com/pdevine/tensor"
)
// This example builds and manipulates a tensor whose elements are *MyType.
// The first step is declaring the element type itself.

// MyType is a pair of ints used as the element type of the example tensor.
type MyType struct {
	x int
	y int
}
// Format implements fmt.Formatter. It always renders the value as "(x, y)",
// ignoring the verb and any flags. Because the receiver is a value, *MyType
// values print the same way.
func (m MyType) Format(st fmt.State, _ rune) {
	fmt.Fprintf(st, "(%d, %d)", m.x, m.y)
}
var (
	// MyDtype is the tensor.Dtype describing *MyType elements. It is assigned
	// in the init function below.
	MyDtype tensor.Dtype
)
// MyEngine is a tensor engine that adds MyType values itself. Every other
// method, including addition of other Dtypes, comes from the embedded
// tensor.StdEng.
type MyEngine struct{ tensor.StdEng }
// For simplicity's sake, we only handle MyType-MyType and MyType-Int
// interactions (in either order) and hand everything else to the embedded
// StdEng. We also only expect Dense tensors. You're of course free to define
// your own rules.

// Add adds two tensors element-wise.
//
// When an operand holds *MyType elements, the sum is accumulated in place
// into that operand's elements (they are mutated through their pointers),
// and that operand is returned. The other operand is left untouched.
// Supported combinations are MyType+MyType, MyType+Int and Int+MyType.
//
// An error is returned when:
//   - the operands have different numbers of elements, or
//   - a MyType left operand is paired with any other dtype.
//
// All remaining combinations are delegated to e.StdEng.Add, together with
// opts. opts are otherwise ignored.
func (e MyEngine) Add(a, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) {
	// The loops below index datb[i] / data[i] across both operands. Check the
	// lengths up front so a mismatch yields an error rather than a panic.
	lengthErr := func(na, nb int) error {
		return errors.Errorf("cannot add tensors with %d and %d elements", na, nb)
	}

	switch ad, bd := a.Dtype(), b.Dtype(); {
	case ad == MyDtype && bd == MyDtype:
		data := a.Data().([]*MyType)
		datb := b.Data().([]*MyType)
		if len(data) != len(datb) {
			return nil, lengthErr(len(data), len(datb))
		}
		for i, v := range data {
			v.x += datb[i].x
			v.y += datb[i].y
		}
		return a, nil
	case ad == MyDtype && bd == tensor.Int:
		data := a.Data().([]*MyType)
		datb := b.Data().([]int)
		if len(data) != len(datb) {
			return nil, lengthErr(len(data), len(datb))
		}
		for i, v := range data {
			v.x += datb[i]
			v.y += datb[i]
		}
		return a, nil
	case ad == tensor.Int && bd == MyDtype:
		data := a.Data().([]int)
		datb := b.Data().([]*MyType)
		if len(data) != len(datb) {
			return nil, lengthErr(len(data), len(datb))
		}
		for i, v := range datb {
			v.x += data[i]
			v.y += data[i]
		}
		// The MyType operand (b) holds the result, so it must be returned.
		return b, nil
	case ad == MyDtype:
		return nil, errors.Errorf("cannot add a tensor of %v to a tensor of MyType", bd)
	default:
		return e.StdEng.Add(a, b, opts...)
	}
}
// init assigns MyDtype, the Dtype describing *MyType elements.
// The composite literal is keyed because go vet's composites check flags
// unkeyed literals of struct types declared in other packages.
func init() {
	MyDtype = tensor.Dtype{Type: reflect.TypeOf((*MyType)(nil))}
}
// Example_extension shows how to plug a custom element type (*MyType) and a
// custom engine (MyEngine) into the tensor package. Adding an Int tensor to
// the MyType tensor updates the MyType elements in place, which is why T and
// T2 print the same values below.
func Example_extension() {
	T := tensor.New(
		tensor.WithEngine(MyEngine{}),
		tensor.WithShape(2, 2),
		tensor.WithBacking([]*MyType{
			{0, 0}, {0, 1},
			{1, 0}, {1, 1},
		}),
	)
	ones := tensor.New(
		tensor.WithShape(2, 2),
		tensor.WithBacking([]int{1, 1, 1, 1}),
		tensor.WithEngine(MyEngine{}),
	)

	T2, err := T.Add(ones)
	if err != nil {
		// Surfacing the error makes a failing Add show up as an output
		// mismatch with a useful message instead of a nil-tensor print.
		fmt.Println(err)
		return
	}
	fmt.Printf("T:\n%+v", T)
	fmt.Printf("T2:\n%+v", T2)
	// Output:
	// T:
	// Matrix (2, 2) [2 1]
	// ⎡(1, 1) (1, 2)⎤
	// ⎣(2, 1) (2, 2)⎦
	// T2:
	// Matrix (2, 2) [2 1]
	// ⎡(1, 1) (1, 2)⎤
	// ⎣(2, 1) (2, 2)⎦
}