-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathwhile_loop_ext.cpp
110 lines (81 loc) · 2.46 KB
/
while_loop_ext.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <nanobind/stl/pair.h>
#include <drjit/while_loop.h>
#include <drjit-core/python.h>
#include <drjit/packet.h>
#include <drjit/random.h>
namespace nb = nanobind;
namespace dr = drjit;
using namespace nb::literals;
/// Exercise ``dr::while_loop`` with two loop-carried variables.
///
/// ``i`` starts as ``dr::arange<UInt>(7)`` and ``j`` as zero. While ``i < 5``
/// the loop body increments ``i`` and sets ``j = i + 4``. The final ``(i, j)``
/// pair is returned so the caller can inspect the loop result.
///
/// \tparam UInt  32-bit unsigned type: a Dr.Jit ``DiffArray`` for the JIT
///               bindings, or plain ``uint32_t`` for ``scalar_loop``.
template <typename UInt> drjit::tuple<UInt, UInt> simple_loop() {
    UInt i = dr::arange<UInt>(7),
         j = 0;

    // The lambda parameters use distinct names so they do not shadow the
    // loop-state variables declared above.
    drjit::tie(i, j) = dr::while_loop(
        dr::make_tuple(i, j),
        [](const UInt &i_, const UInt &) {
            return i_ < 5;
        },
        [](UInt &i_, UInt &j_) {
            i_ += 1;
            j_ = i_ + 4;
        }
    );

    return { i, j };
}
/// Minimal stateful sampler wrapping a Dr.Jit PCG32 generator, used as loop
/// state (by pointer) in ``loop_with_rng``.
///
/// The two ``traverse_1_cb_*`` hooks forward the traversal callback to the
/// ``rng`` member. NOTE(review): presumably this is what lets
/// ``dr::while_loop`` discover and update the generator state; confirm
/// against the Dr.Jit traversal documentation.
template <typename T>
struct Sampler {
    /// ``size`` is forwarded to the PCG32 constructor (presumably the number
    /// of generator lanes). ``explicit`` prevents accidental implicit
    /// conversions from integers to ``Sampler``.
    explicit Sampler(size_t size) : rng(size) { }

    /// Draw the next single-precision sample from the generator.
    T next() { return rng.next_float32(); }

    /// Read-only traversal: report the generator state to ``fn``.
    void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) const {
        traverse_1_fn_ro(rng, payload, fn);
    }

    /// Read-write traversal: let ``fn`` replace the generator state.
    void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) {
        traverse_1_fn_rw(rng, payload, fn);
    }

    dr::PCG32<dr::uint64_array_t<T>> rng;
};
template <typename UInt> UInt loop_with_rng() {
using Bool = dr::mask_t<UInt>;
using Sampler = Sampler<dr::float32_array_t<UInt>>;
auto s1 = dr::make_unique<Sampler>(3);
auto s2 = dr::make_unique<Sampler>(3);
UInt i = dr::arange<UInt>(3);
Sampler *s = s1.get();
drjit::tie(i, s) = dr::while_loop(
dr::make_tuple(i, s),
[](const UInt &i, const Sampler *) {
return i < 3;
},
[](UInt &i, Sampler *s) {
i += 1;
s->rng.next_float32();
}
);
return UInt(s1->rng - s2->rng);
}
bool packet_loop() {
using Float = dr::Packet<float, 16>;
using RNG = dr::PCG32<Float>;
RNG a, b, c;
for (int i = 0; i < 1000; ++i)
a.next_float32();
b += 1000;
return dr::all(a.next_float32() == b.next_float32()) && dr::all((b - c) == 1001);
}
/// Register the JIT variants of the loop tests on ``m`` for ``Backend``,
/// instantiated with differentiable 32-bit unsigned arrays.
template <JitBackend Backend> void bind(nb::module_ &m) {
    using Index = dr::DiffArray<Backend, uint32_t>;

    m.def("simple_loop", &simple_loop<Index>)
     .def("loop_with_rng", &loop_with_rng<Index>);
}
// Python entry point. The JIT variants live in per-backend submodules that
// are only created when that backend is enabled at build time. The scalar and
// packet variants are registered directly on the top-level module.
NB_MODULE(while_loop_ext, m) {
#if defined(DRJIT_ENABLE_LLVM)
    nb::module_ llvm_mod = m.def_submodule("llvm");
    bind<JitBackend::LLVM>(llvm_mod);
#endif

#if defined(DRJIT_ENABLE_CUDA)
    nb::module_ cuda_mod = m.def_submodule("cuda");
    bind<JitBackend::CUDA>(cuda_mod);
#endif

    m.def("scalar_loop", &simple_loop<uint32_t>)
     .def("packet_loop", &packet_loop);
}