-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathcustom_type_ext.cpp
102 lines (78 loc) · 2.84 KB
/
custom_type_ext.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <drjit/python.h>
#include <drjit/autodiff.h>
#include <drjit/packet.h>
namespace nb = nanobind;
namespace dr = drjit;
using namespace nb::literals;
/// Three-component color array built on Dr.Jit's static array machinery.
///
/// The channel accessors r()/g()/b() are thin aliases for the base array's
/// x()/y()/z() components; each comes in a mutable and a read-only flavor.
template <typename Value_>
struct Color : dr::StaticArrayImpl<Value_, 3, false, Color<Value_>> {
    using Base = dr::StaticArrayImpl<Value_, 3, false, Color<Value_>>;

    /// Type-promotion hook: the same container shape holding a different
    /// value type (used by Dr.Jit's promotion rules).
    template <typename T> using ReplaceValue = Color<T>;

    using ArrayType = Color;
    using MaskType  = dr::Mask<Value_, 3>;

    // --- Mutable channel access -------------------------------------------
    decltype(auto) r() { return Base::x(); }
    decltype(auto) g() { return Base::y(); }
    decltype(auto) b() { return Base::z(); }

    // --- Read-only channel access -----------------------------------------
    decltype(auto) r() const { return Base::x(); }
    decltype(auto) g() const { return Base::y(); }
    decltype(auto) b() const { return Base::z(); }

    DRJIT_ARRAY_IMPORT(Color, Base)
};
/// Minimal user-defined container that wraps a single value.
///
/// It exists to exercise Dr.Jit on a custom type that is not a DRJIT_STRUCT.
/// The `schedule_force_()` member forwards to `dr::detail::schedule_force()`
/// on the wrapped value.
/// NOTE(review): presumably this is the hook `dr::make_opaque()` uses for
/// such types (cf. `cpp_make_opaque`); confirm against the Dr.Jit headers.
template <typename Value>
struct CustomHolder {
    CustomHolder() { }
    CustomHolder(const Value &initial) : m_value(initial) { }

    /// Mutable access to the wrapped value.
    Value &value() { return m_value; }

    /// Forwards to dr::detail::schedule_force() on the wrapped value and
    /// returns its result unchanged.
    bool schedule_force_() {
        Value &target = m_value;
        return dr::detail::schedule_force(target);
    }

private:
    Value m_value;
};
/// Registers the per-backend test bindings into module `m`:
///  - `Color3f`: Dr.Jit array binding of Color<DiffArray<Backend, float>>
///    with read/write `r`, `g`, `b` properties;
///  - `CustomFloatHolder`: nanobind class wrapping CustomHolder<Float>,
///    constructible from a Float and exposing `value()`;
///  - `cpp_make_opaque(holder)`: calls dr::make_opaque() on a holder.
template <JitBackend Backend> void bind(nb::module_ &m) {
    dr::ArrayBinding b;
    using Float = dr::DiffArray<Backend, float>;
    using Color3f = Color<Float>;

    // The getters return references into the Color3f. nanobind's def_prop_rw
    // applies rv_policy::reference_internal to getters by default, which keeps
    // the owning Color3f alive. The setters only read their argument, so they
    // take it by const reference.
    dr::bind_array_t<Color3f>(b, m, "Color3f")
        .def_prop_rw("r",
            [](Color3f &c) -> Float & { return c.r(); },
            [](Color3f &c, const Float &value) { c.r() = value; })
        .def_prop_rw("g",
            [](Color3f &c) -> Float & { return c.g(); },
            [](Color3f &c, const Float &value) { c.g() = value; })
        .def_prop_rw("b",
            [](Color3f &c) -> Float & { return c.b(); },
            [](Color3f &c, const Float &value) { c.b() = value; });

    using CustomFloatHolder = CustomHolder<Float>;

    // value() returns a reference to the holder's member. reference_internal
    // (rather than plain `reference`) ties the returned object's lifetime to
    // the holder, so the Float cannot dangle if the holder is collected first.
    nb::class_<CustomFloatHolder>(m, "CustomFloatHolder")
        .def(nb::init<Float>())
        .def("value", &CustomFloatHolder::value,
             nb::rv_policy::reference_internal);

    m.def("cpp_make_opaque",
          [](CustomFloatHolder &holder) { dr::make_opaque(holder); });
}
// Extension module for the custom-type tests. The per-backend bindings go into
// `llvm` / `cuda` submodules. Each submodule is compiled in only when the
// matching DRJIT_ENABLE_* macro is defined.
NB_MODULE(custom_type_ext, m) {
#if defined(DRJIT_ENABLE_LLVM)
    nb::module_ llvm_module = m.def_submodule("llvm");
    bind<JitBackend::LLVM>(llvm_module);
#endif

#if defined(DRJIT_ENABLE_CUDA)
    nb::module_ cuda_module = m.def_submodule("cuda");
    bind<JitBackend::CUDA>(cuda_module);
#endif

    // Tests: DRJIT_STRUCT, traversal mechanism, array/struct stringification.
    // Builds a zero-initialized Ray over 4-wide packets, sets the flag and
    // lane 2 of o.y, then returns dr::string() of the result.
    m.def("struct_to_string", [] {
        using FloatP    = dr::Packet<float, 4>;
        using Vector3fP = dr::Array<FloatP, 3>;

        struct Ray {
            FloatP time;
            Vector3fP o, d;
            bool has_ray_differentials;
            DRJIT_STRUCT(Ray, time, o, d, has_ray_differentials)
        };

        Ray ray = dr::zeros<Ray>();
        ray.has_ray_differentials = true;
        ray.o.y()[2] = 3;
        return dr::string(ray);
    });
}