# test_conversion.py
import drjit as dr
import pytest


def skip_tf_if_not_available(t):
    # Skip overall if TF is not available
    pytest.importorskip("tensorflow.config")

    # Skip CUDA backend roundtrip if TensorFlow doesn't support
    # CUDA, e.g. on native Windows since version 2.11.
    from tensorflow.config import list_physical_devices
    if dr.backend_v(t) == dr.JitBackend.CUDA and not list_physical_devices("GPU"):
        pytest.skip("TensorFlow didn't detect a CUDA device, skipping.")
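
# All roundtrip tests below follow the same pattern: convert a Dr.Jit array
# to the target framework, convert it back, and check that both the shape
# and the contents survive. For tensors, the underlying flat storage (the
# '.array' member) is roundtripped as well. A minimal sketch of the pattern,
# assuming numpy is installed (the concrete dr.llvm.Float type is just for
# illustration):
#
#     x = dr.llvm.Float([1, 2, 3])
#     assert dr.all(type(x)(x.numpy()) == x)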

# Test conversions to/from numpy (tensors & dynamic arrays)
@pytest.test_arrays('is_tensor, -bool, -float16')
def test01_roundtrip_dynamic_numpy(t):
    pytest.importorskip("numpy")

    a = t([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
    roundtrip = t(a.numpy())
    assert roundtrip.shape == (2, 2, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)

    # Also roundtrip the tensor's flat storage array
    flat_t = type(a.array)
    roundtrip = flat_t(a.array.numpy())
    assert dr.all(a.array == roundtrip, axis=None)
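
# The 'vector' variants below exercise statically sized arrays
# (shape=(3, *), e.g. an Array3f), whose leading dimension is fixed at
# compile time rather than stored as part of a tensor shape.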

# Test conversions to/from numpy (vectors)
@pytest.test_arrays('vector, shape=(3, *), -bool, -float16')
def test02_roundtrip_vector_numpy(t):
    pytest.importorskip("numpy")

    a = t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    roundtrip = t(a.numpy())
    assert roundtrip.shape == (3, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)
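
# The torch tests additionally exclude uint32/uint64, presumably because
# PyTorch has historically lacked (or only partially supports) these
# unsigned integer dtypes, so a roundtrip could not preserve them.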

# Test conversions to/from torch (tensors & dynamic array)
@pytest.test_arrays('tensor, -bool, -float16, -uint64, -uint32')
def test03_roundtrip_dynamic_torch(t):
    pytest.importorskip("torch")

    a = t([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
    roundtrip = t(a.torch())
    assert roundtrip.shape == (2, 2, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)

    # The flat storage array is roundtripped via numpy here as well
    flat_t = type(a.array)
    roundtrip = flat_t(a.array.numpy())
    assert dr.all(a.array == roundtrip, axis=None)

# Test conversions to/from torch (vectors)
@pytest.test_arrays('vector, shape=(3, *), -bool, -uint64, -uint32')
def test04_roundtrip_vector_torch(t):
    pytest.importorskip("torch")

    a = t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    roundtrip = t(a.torch())
    assert roundtrip.shape == (3, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)

# Test conversions to/from tf (tensors & dynamic array)
@pytest.test_arrays('tensor, -bool, -float16')
def test05_roundtrip_dynamic_tf(t):
    skip_tf_if_not_available(t)

    a = t([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
    roundtrip = t(a.tf())
    assert roundtrip.shape == (2, 2, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)

    flat_t = type(a.array)
    roundtrip = flat_t(a.array.numpy())
    assert dr.all(a.array == roundtrip, axis=None)

# Test conversions to/from tf (vectors)
@pytest.test_arrays('vector, shape=(3, *), -bool, -float16')
def test06_roundtrip_vector_tf(t):
    skip_tf_if_not_available(t)

    a = t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    roundtrip = t(a.tf())
    assert roundtrip.shape == (3, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)
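
# The JAX tests exclude all 64-bit types, presumably because JAX defaults
# to 32-bit precision and silently downcasts 64-bit values unless the
# 'jax_enable_x64' option is set, which would break an exact roundtrip.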

# Test conversions to/from jax (tensors & dynamic array)
@pytest.test_arrays('tensor, -bool, -uint64, -int64, -float64')
def test07_roundtrip_dynamic_jax(t):
    pytest.importorskip("jax")

    a = t([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
    roundtrip = t(a.jax())
    assert roundtrip.shape == (2, 2, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)

    flat_t = type(a.array)
    roundtrip = flat_t(a.array.numpy())
    assert dr.all(a.array == roundtrip, axis=None)

# Test conversions to/from jax (vectors)
@pytest.test_arrays('vector, shape=(3, *), -bool, -uint64, -int64, -float64')
def test08_roundtrip_vector_jax(t):
    pytest.importorskip("jax")

    a = t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    roundtrip = t(a.jax())
    assert roundtrip.shape == (3, 3) and roundtrip.shape == a.shape
    assert dr.all(a == roundtrip, axis=None)
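
# The tests below check aliasing behavior: whether the converted array
# shares memory with the Dr.Jit array, so that in-place modifications on
# one side become visible on the other.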

# Test in-place modifications from numpy (tensors & dynamic array)
@pytest.test_arrays('tensor, -bool, -float16, -uint64, -uint32')
def test09_inplace_numpy(t):
    pytest.importorskip("numpy")

    a = dr.zeros(t, shape=(3, 3, 3))
    x = a.numpy()
    x[0, 0, 0] = 1

    backend = dr.backend_v(a)
    if backend == dr.JitBackend.LLVM or backend == dr.JitBackend.Invalid:
        # CPU-resident arrays are exposed to numpy without a copy, so
        # the write through 'x' is visible on the Dr.Jit side
        assert a[0, 0, 0] == x[0, 0, 0]
    elif backend == dr.JitBackend.CUDA:
        # numpy cannot reference GPU memory, so the conversion made a
        # host copy and the original array is unchanged
        assert a[0, 0, 0] == 0
        assert x[0, 0, 0] == 1
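
# Unlike numpy, torch tensors can reference GPU memory directly, so the
# conversion is expected to alias Dr.Jit's storage on both the LLVM and
# CUDA backends, and the write below must be visible on both sides.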

# Test in-place modifications from torch (tensors & dynamic array)
@pytest.test_arrays('tensor, -bool, -float16, -uint64, -uint32')
def test10_inplace_torch(t):
    pytest.importorskip("torch")

    a = dr.empty(t, shape=(3, 3, 3))
    x = a.torch()
    x[0, 0, 0] = 1
    assert a[0, 0, 0] == x[0, 0, 0]
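
# Converting to another framework inside a dr.suspend_grad() block must
# not detach the variable: gradient tracking and the AD index should be
# unchanged afterwards.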

# Test AD index preservation after conversion
@pytest.test_arrays('is_diff,float32,shape=(*)')
def test11_conversion_ad(t):
    pytest.importorskip("numpy")

    x = dr.ones(t)
    dr.enable_grad(x)
    i = x.index_ad

    with dr.suspend_grad():
        y = x.numpy()

    assert dr.grad_enabled(x)
    assert i != 0
    assert i == x.index_ad
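
# Note: the tests above only cover data exchange. Differentiating across
# the framework boundary is a separate feature (see dr.wrap in recent
# Dr.Jit versions, which bridges AD between Dr.Jit and torch/jax) and is
# not exercised here.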