forked from gorgonia/tensor
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtensor.go
170 lines (146 loc) · 3.89 KB
/
tensor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
// Package tensor is a package that provides efficient, generic n-dimensional arrays in Go.
// Also in this package are functions and methods that are used commonly in arithmetic, comparison and linear algebra operations.
package tensor // import "gorgonia.org/tensor"
import (
"encoding/gob"
"fmt"
"io"
"github.com/pkg/errors"
)
// Compile-time checks that the concrete tensor types satisfy the interfaces
// declared in this package. The nil pointers are never dereferenced; only
// their static types matter.
var (
	_ Tensor = (*Dense)(nil)
	_ Tensor = (*CS)(nil)
	_ View   = (*Dense)(nil)
)
// init registers the concrete tensor types with encoding/gob, in the order
// *Dense then *CS, so that they can be transmitted when stored in interface
// values.
func init() {
	for _, v := range []interface{}{&Dense{}, &CS{}} {
		gob.Register(v)
	}
}
// Tensor represents a variety of n-dimensional arrays. The most commonly used
// tensor is the Dense tensor. It can be used to represent a vector, a matrix,
// a 3D matrix, and higher-dimensional tensors.
//
// In this file, *Dense and *CS are both asserted at compile time to implement
// Tensor. Because the interface includes the unexported method standardEngine,
// only types declared in this package (or types embedding one of them) can
// satisfy it.
type Tensor interface {
	// info about the ndarray
	Shape() Shape
	Strides() []int
	Dtype() Dtype
	Dims() int
	Size() int
	DataSize() int

	// Data access related
	RequiresIterator() bool
	Iterator() Iterator
	DataOrder() DataOrder

	// ops
	Slicer
	At(...int) (interface{}, error)
	SetAt(v interface{}, coord ...int) error
	Reshape(...int) error
	// T transposes along the given axes.
	// NOTE(review): by contrast with Transpose below (which "actually moves
	// the data"), T presumably only records the transposition lazily — confirm.
	T(axes ...int) error
	// UT presumably undoes a pending T — NOTE(review): confirm against
	// implementations.
	UT()
	Transpose() error // Transpose actually moves the data
	Apply(fn interface{}, opts ...FuncOpt) (Tensor, error)

	// data related interface
	Zeroer
	MemSetter
	Dataer
	Eq
	Cloner

	// type overloading methods
	IsScalar() bool
	ScalarValue() interface{}

	// engine/memory related stuff
	// All Tensors should be expressible as a slab of memory.
	// Note: the size of each element can be acquired by T.Dtype().Size().
	Memory                      // Tensors all implement Memory
	Engine() Engine             // Engine can be nil
	IsNativelyAccessible() bool // Can Go access the memory
	// IsManuallyManaged reports whether the memory is manually managed.
	// NOTE(review): going by the name, true presumably means the Go runtime
	// does NOT own the memory — confirm against implementations.
	IsManuallyManaged() bool

	// formatters
	fmt.Formatter
	fmt.Stringer

	// all Tensors are serializable to these formats
	WriteNpy(io.Writer) error
	ReadNpy(io.Reader) error
	gob.GobEncoder
	gob.GobDecoder

	// package-internal hooks
	standardEngine() standardEngine
	headerer
	arrayer
}
// New constructs a *Dense tensor by applying each construction option in
// order to a freshly obtained *Dense, then finalizing it with fix.
//
// New panics if the resulting tensor fails its sanity check. Sparse tensors
// are not built here; use their dedicated construction functions instead.
func New(opts ...ConsOpt) *Dense {
	retVal := borrowDense()
	for i := range opts {
		opts[i](retVal)
	}
	retVal.fix()

	err := retVal.sanity()
	if err != nil {
		panic(err)
	}
	return retVal
}
// assertDense returns t as a *Dense.
//
// A *Dense is returned unchanged (including a typed-nil *Dense). Any other
// Densor is converted via its Dense method. A nil t, or any other type,
// yields an error.
func assertDense(t Tensor) (*Dense, error) {
	switch tt := t.(type) {
	case nil:
		return nil, errors.New("nil is not a *Dense")
	case *Dense:
		return tt, nil
	case Densor:
		return tt.Dense(), nil
	default:
		return nil, errors.Errorf("%T is not *Dense", t)
	}
}
// getDenseTensor returns t as a DenseTensor.
//
// If t already implements DenseTensor it is returned as-is. Otherwise, if t
// is a Densor, the result of its Dense method is returned. Anything else,
// including a nil t, yields an error.
func getDenseTensor(t Tensor) (DenseTensor, error) {
	if dt, ok := t.(DenseTensor); ok {
		return dt, nil
	}
	if conv, ok := t.(Densor); ok {
		return conv.Dense(), nil
	}
	return nil, errors.Errorf("Tensor %T is not a DenseTensor", t)
}
// getFloatDenseTensor extracts a DenseTensor from t, requiring that t's Dtype
// is a member of the floatTypes typeclass.
//
// A nil t is passed through as (nil, nil) rather than treated as an error, so
// callers must check the returned tensor for nil themselves.
//
// An error is returned if t's Dtype is not a float type, or if t cannot be
// converted to a DenseTensor (see getDenseTensor).
func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) {
	if t == nil {
		return nil, nil
	}
	if err = typeclassCheck(t.Dtype(), floatTypes); err != nil {
		return nil, errors.Wrapf(err, "getFloatDense only handles floats. Got %v instead", t.Dtype())
	}
	if retVal, err = getDenseTensor(t); err != nil {
		return nil, errors.Wrapf(err, opFail, "getFloatDense")
	}
	return retVal, nil
}
// getFloatComplexDenseTensor extracts a DenseTensor from t, requiring that t's
// Dtype is a member of the floatcmplxTypes typeclass (float or complex types).
//
// A nil t is passed through as (nil, nil) rather than treated as an error, so
// callers must check the returned tensor for nil themselves.
//
// An error is returned if t's Dtype is not in floatcmplxTypes, or if t cannot
// be converted to a DenseTensor (see getDenseTensor).
func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) {
	if t == nil {
		return nil, nil
	}
	if err = typeclassCheck(t.Dtype(), floatcmplxTypes); err != nil {
		return nil, errors.Wrapf(err, "getFloatComplexDenseTensor only handles floats and complex. Got %v instead", t.Dtype())
	}
	if retVal, err = getDenseTensor(t); err != nil {
		return nil, errors.Wrapf(err, opFail, "getFloatComplexDenseTensor")
	}
	return retVal, nil
}
// sliceDense slices t with the given slices and returns the result as a
// *Dense.
//
// Any error from t.Slice is returned unchanged. The result of t.Slice is
// type-asserted with the comma-ok form, so an unexpected concrete type (or a
// nil result) is reported as an error instead of causing a runtime panic.
func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) {
	var sliced Tensor
	if sliced, err = t.Slice(slices...); err != nil {
		return nil, err
	}
	d, ok := sliced.(*Dense)
	if !ok {
		return nil, errors.Errorf("sliceDense: expected *Dense from Slice, got %T", sliced)
	}
	return d, nil
}