forked from gorgonia/tensor
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtype_test.go
66 lines (58 loc) · 1.66 KB
/
type_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package tensor
import (
"reflect"
"testing"
)
// Float16 is a test-only type whose underlying type is uint16. The tests in
// this file use it as a Dtype that is not built in: TestRegisterType registers
// it as a float, and TestDtypeConversions expects it to have no numpy dtype.
// NOTE(review): despite the name, this file declares no half-precision
// semantics for it — presumably it is just a distinct uint16 to reflect on.
type Float16 uint16
// TestRegisterType checks that RegisterFloat makes a user-defined Dtype a
// member of every typeclass a float is expected to belong to: floatTypes,
// numberTypes, ordTypes and eqTypes.
//
// NOTE(review): RegisterFloat appears to mutate package-level typeclass state
// that this test never resets — confirm; if so, Float16 stays registered for
// every test that runs afterwards.
func TestRegisterType(t *testing.T) {
	half := Dtype{reflect.TypeOf(Float16(0))}
	RegisterFloat(half)

	// Each check runs here, after registration, in row order.
	results := []struct {
		class string
		err   error
	}{
		{"floatTypes", typeclassCheck(half, floatTypes)},
		{"numberTypes", typeclassCheck(half, numberTypes)},
		{"ordTypes", typeclassCheck(half, ordTypes)},
		{"eqTypes", typeclassCheck(half, eqTypes)},
	}
	for _, r := range results {
		if r.err == nil {
			continue
		}
		t.Errorf("Expected %v to be in %s: %v", half, r.class, r.err)
	}
}
// TestDtypeConversions checks the mapping between Dtypes and numpy dtype
// strings in both directions:
//
//   - every entry of reverseNumpyDtypes converts via Dtype.numpyDtype to its key;
//   - a Dtype with no numpy equivalent (Float16) makes numpyDtype return an error;
//   - every entry of numpyDtypes converts back via fromNumpyDtype, tolerating
//     the platform-sized Int/Uint cases described at platformAlias;
//   - an unrecognized numpy string makes fromNumpyDtype return an error.
func TestDtypeConversions(t *testing.T) {
	for want, dt := range reverseNumpyDtypes {
		got, err := dt.numpyDtype()
		if err != nil {
			// Report the error first: the string returned alongside an error
			// should not be relied on, and comparing it would only produce a
			// misleading mismatch message that hides the real failure.
			t.Errorf("%v.numpyDtype(): unexpected error: %v", dt, err)
			continue
		}
		if got != want {
			t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", dt, want, got)
		}
	}

	unknown := Dtype{reflect.TypeOf(Float16(0))}
	if _, err := unknown.numpyDtype(); err == nil {
		t.Error("Expected an error when passing in type unknown to np")
	}

	// platformAlias reports whether fromNumpyDtype resolving code to got is an
	// accepted special case: a 4- or 8-byte integer code may come back as the
	// platform-sized Int/Uint (when its size matches) instead of the
	// fixed-width Dtype the map entry names.
	platformAlias := func(code string, got Dtype) bool {
		switch got {
		case Int:
			return (Int.Size() == 4 && code == "i4") || (Int.Size() == 8 && code == "i8")
		case Uint:
			return (Uint.Size() == 4 && code == "u4") || (Uint.Size() == 8 && code == "u8")
		}
		return false
	}

	for want, code := range numpyDtypes {
		got, err := fromNumpyDtype(code)
		if err != nil {
			t.Errorf("fromNumpyDtype(%q): unexpected error: %v", code, err)
			continue
		}
		if got == want || platformAlias(code, got) {
			continue
		}
		t.Errorf("Expected %q to return %v. Got %v instead", code, want, got)
	}

	if _, err := fromNumpyDtype("EDIUH"); err == nil {
		t.Error("Expected error when nonsense is passed into fromNumpyDtype")
	}
}