-
Notifications
You must be signed in to change notification settings - Fork 524
/
Copy pathpybindings.cpp
407 lines (365 loc) · 11.7 KB
/
pybindings.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// pybind11 (and therefore Python.h) must come before any standard header.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <memory>
#include <stack>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "executorch/extension/pytree/pytree.h"
namespace py = pybind11;
namespace executorch {
namespace extension {
namespace pytree {
namespace {
// Auxiliary payload attached to every TreeSpec node built by these bindings.
// Only Kind::Custom nodes fill it in: flatten_internal stores the second
// element returned by the registered flatten function here, and
// unflatten_internal passes it back to the registered unflatten function.
struct PyAux {
  py::object custom_type_context{};
};

// The TreeSpec specialization used throughout the Python bindings.
using PyTreeSpec = TreeSpec<PyAux>;
// Registry mapping a Python type's str() form (e.g. "<class 'tuple'>") to how
// pytree treats values of that type. tuple, list and dict are pre-registered;
// user types are added via register_custom_type. Because lookups are keyed by
// the type's string form, only the exact registered type matches (a subclass
// has a different str()).
class PyTypeRegistry {
 public:
  // Registration record for one Python type.
  struct PyTypeReg {
    explicit PyTypeReg(Kind k) : kind(k) {}

    Kind kind;
    // The registered Python type object (populated for Kind::Custom only).
    py::object type;
    // function type: object -> (children, spec_data)
    py::function flatten;
    // function type: (children, spec_data) -> object
    py::function unflatten;
  };

  // Returns the registration whose key equals `pytype`, or nullptr if there
  // is none.
  static const PyTypeReg* get_by_str(const std::string& pytype) {
    const auto* registry = instance();
    const auto it = registry->regs_.find(pytype);
    return it == registry->regs_.end() ? nullptr : it->second.get();
  }

  // Returns the registration for the Python type object `pytype` (looked up by
  // its str()), or nullptr if it is not registered.
  static const PyTypeReg* get_by_type(py::handle pytype) {
    return get_by_str(py::str(pytype));
  }

  // Registers `type` as a custom pytree node with the given flatten/unflatten
  // functions.
  //
  // Throws std::runtime_error if a type with the same str() is already
  // registered. An assert() here would compile away under NDEBUG and silently
  // keep the first registration.
  static void register_custom_type(
      py::object type,
      py::function flatten,
      py::function unflatten) {
    // Compute the key before `type` is moved into the record.
    std::string pytype_str = py::str(type);
    auto reg = std::make_unique<PyTypeReg>(Kind::Custom);
    reg->type = std::move(type);
    reg->flatten = std::move(flatten);
    reg->unflatten = std::move(unflatten);
    const bool inserted =
        instance()->regs_.emplace(pytype_str, std::move(reg)).second;
    if (!inserted) {
      throw std::runtime_error(
          std::string("pytree type ") + pytype_str +
          " is already registered");
    }
  }

 private:
  // Returns the lazily created singleton, seeded with tuple/list/dict.
  // Function-local static initialization is thread-safe since C++11. The
  // registry is allocated with `new` and never deleted, presumably so its
  // py::object members are never destroyed during static destruction after
  // the interpreter may have shut down (NOTE(review): confirm intent).
  static PyTypeRegistry* instance() {
    static auto* registry_instance = []() -> PyTypeRegistry* {
      auto* registry = new PyTypeRegistry;
      auto add_pytype_reg = [&](const std::string& pytype, Kind kind) {
        registry->regs_.emplace(pytype, std::make_unique<PyTypeReg>(kind));
      };
      add_pytype_reg("<class 'tuple'>", Kind::Tuple);
      add_pytype_reg("<class 'list'>", Kind::List);
      add_pytype_reg("<class 'dict'>", Kind::Dict);
      return registry;
    }();
    return registry_instance;
  }

  // Keyed by str(type), e.g. "<class 'tuple'>".
  std::unordered_map<std::string, std::unique_ptr<PyTypeReg>> regs_;
};
// Python-facing wrapper around a PyTreeSpec (exposed to Python as `TreeSpec`).
// Static helpers flatten nested Python containers into a flat list of leaves
// plus a spec. Member helpers rebuild a structure from leaves and the spec.
class PyTree {
  PyTreeSpec spec_;

  // Recursively flattens `x`: appends its leaves to `leaves` in depth-first,
  // left-to-right order and writes the structure of `x` into `s`.
  //
  // Kind selection works in three steps:
  // - First, look up str(type(x)) in PyTypeRegistry (exact type match only).
  // - Failing that, a tuple instance with a `_fields` attribute is a NamedTuple.
  // - Everything else is a leaf.
  //
  // Throws std::runtime_error for dict keys that are neither str nor int, and
  // for a custom flatten function that does not return a 2-element tuple.
  static void flatten_internal(
      py::handle x,
      std::vector<py::object>& leaves,
      PyTreeSpec& s) {
    const auto* reg = PyTypeRegistry::get_by_type(x.get_type());
    // `reg` is a pointer, so capturing it by value is sufficient.
    const auto kind = [reg, &x]() {
      if (reg) {
        return reg->kind;
      }
      if (py::isinstance<py::tuple>(x) && py::hasattr(x, "_fields")) {
        return Kind::NamedTuple;
      }
      return Kind::Leaf;
    }();
    switch (kind) {
      case Kind::List: {
        const size_t n = PyList_GET_SIZE(x.ptr());
        s = PyTreeSpec(Kind::List, n);
        for (size_t i = 0; i < n; ++i) {
          flatten_internal(PyList_GET_ITEM(x.ptr(), i), leaves, s[i]);
        }
        break;
      }
      case Kind::Tuple: {
        const size_t n = PyTuple_GET_SIZE(x.ptr());
        s = PyTreeSpec(Kind::Tuple, n);
        for (size_t i = 0; i < n; ++i) {
          flatten_internal(PyTuple_GET_ITEM(x.ptr(), i), leaves, s[i]);
        }
        break;
      }
      case Kind::NamedTuple: {
        py::tuple tuple = py::reinterpret_borrow<py::tuple>(x);
        const size_t n = tuple.size();
        s = PyTreeSpec(Kind::NamedTuple, n);
        size_t i = 0;
        for (py::handle entry : tuple) {
          flatten_internal(entry, leaves, s[i++]);
        }
        break;
      }
      case Kind::Dict: {
        py::dict dict = py::reinterpret_borrow<py::dict>(x);
        // Snapshot of the keys in dict iteration order.
        py::list keys =
            py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
        const auto n = PyList_GET_SIZE(keys.ptr());
        s = PyTreeSpec(Kind::Dict, n);
        size_t i = 0;
        for (py::handle key : keys) {
          if (py::isinstance<py::str>(key)) {
            s.key(i) = py::cast<std::string>(key);
          } else if (py::isinstance<py::int_>(key)) {
            s.key(i) = py::cast<int32_t>(key);
          } else {
            throw std::runtime_error(
                std::string(
                    "invalid key in pytree dict; must be int or string but got ") +
                std::string(py::str(key.get_type())));
          }
          flatten_internal(dict[key], leaves, s[i]);
          i++;
        }
        break;
      }
      case Kind::Custom: {
        py::tuple out = py::cast<py::tuple>(reg->flatten(x));
        if (out.size() != 2) {
          // An assert() here compiles away under NDEBUG and lets a malformed
          // result through, so report it explicitly.
          throw std::runtime_error(
              std::string("custom pytree flatten for ") +
              std::string(py::str(x.get_type())) +
              " must return a (children, context) tuple of size 2, got size " +
              std::to_string(out.size()));
        }
        py::list children = py::cast<py::list>(out[0]);
        const size_t n = children.size();
        s = PyTreeSpec(Kind::Custom, n);
        s.handle->custom_type = py::str(x.get_type());
        s.handle->custom_type_context = out[1];
        size_t i = 0;
        for (py::handle pychild : children) {
          flatten_internal(pychild, leaves, s[i++]);
        }
        break;
      }
      case Kind::Leaf: {
        s = PyTreeSpec(Kind::Leaf);
        leaves.push_back(py::reinterpret_borrow<py::object>(x));
        break;
      }
      case Kind::None:
        [[fallthrough]];
      default:
        throw std::runtime_error(
            std::string("invalid pytree kind ") + std::to_string(int(kind)) +
            " in flatten_internal");
    }
  }

  // Rebuilds the Python object described by `spec`, consuming leaves from
  // `leaves_it` up to `leaves_end`. `leaves_it` is advanced in place, so each
  // sibling subtree continues where the previous one stopped.
  //
  // NamedTuple nodes are rebuilt as plain tuples. Throws std::runtime_error in
  // these cases:
  // - the leaves run out before `spec` is complete;
  // - a Custom node's type is not registered;
  // - `spec` contains an unknown node kind or dict-key kind.
  template <typename It, typename End>
  py::object unflatten_internal(
      const PyTreeSpec& spec,
      It& leaves_it,
      const End& leaves_end) const {
    switch (spec.kind()) {
      case Kind::NamedTuple:
      case Kind::Tuple: {
        const size_t size = spec.size();
        py::tuple tuple(size);
        for (size_t i = 0; i < size; ++i) {
          tuple[i] = unflatten_internal(spec[i], leaves_it, leaves_end);
        }
        return std::move(tuple);
      }
      case Kind::List: {
        const size_t size = spec.size();
        py::list list(size);
        for (size_t i = 0; i < size; ++i) {
          list[i] = unflatten_internal(spec[i], leaves_it, leaves_end);
        }
        return std::move(list);
      }
      case Kind::Custom: {
        const auto& pytype_str = spec.handle->custom_type;
        const auto* reg = PyTypeRegistry::get_by_str(pytype_str);
        if (reg == nullptr) {
          // Specs can be built from strings (from_str), so the type may never
          // have been registered in this process.
          throw std::runtime_error(
              std::string("pytree custom type ") + std::string(pytype_str) +
              " is not registered");
        }
        const size_t size = spec.size();
        py::list list(size);
        for (size_t i = 0; i < size; ++i) {
          list[i] = unflatten_internal(spec[i], leaves_it, leaves_end);
        }
        py::object o = reg->unflatten(list, spec.handle->custom_type_context);
        return o;
      }
      case Kind::Dict: {
        const size_t size = spec.size();
        py::dict dict;
        for (size_t i = 0; i < size; ++i) {
          const auto& key = spec.key(i);
          // Keep an owning reference: dict item assignment takes its own
          // reference, so a released handle would leak one per key.
          py::object py_key;
          switch (key.kind()) {
            case Key::Kind::Int:
              py_key = py::cast(key.as_int());
              break;
            case Key::Kind::Str:
              py_key = py::cast(key.as_str());
              break;
            default:
              throw std::runtime_error(
                  std::string("invalid key kind ") +
                  std::to_string(int(key.kind())) +
                  " in pytree dict; must be int or string");
          }
          dict[py_key] = unflatten_internal(spec[i], leaves_it, leaves_end);
        }
        return std::move(dict);
      }
      case Kind::Leaf: {
        if (leaves_it == leaves_end) {
          throw std::runtime_error(
              "not enough leaves to unflatten the given TreeSpec");
        }
        py::object o = py::reinterpret_borrow<py::object>(*leaves_it);
        ++leaves_it;
        return o;
      }
      case Kind::None: {
        return py::none();
      }
    }
    throw std::runtime_error(
        std::string("invalid spec kind ") + std::to_string(int(spec.kind())) +
        " in unflatten_internal");
  }

 public:
  explicit PyTree(PyTreeSpec spec) : spec_(std::move(spec)) {}

  const PyTreeSpec& spec() const {
    return spec_;
  }

  // Parses a serialized spec string into a PyTree.
  static PyTree py_from_str(std::string spec) {
    return PyTree(from_str<PyAux>(spec));
  }

  // Serializes this spec to its string form.
  StrTreeSpec py_to_str() const {
    return to_str(spec_);
  }

  // Flattens `x` into (leaves, spec). refresh_leaves_num is run on the new
  // spec before it is wrapped, presumably to fill in per-node leaf counts.
  static std::pair<std::vector<py::object>, std::unique_ptr<PyTree>>
  tree_flatten(py::handle x) {
    std::vector<py::object> leaves{};
    PyTreeSpec spec{};
    flatten_internal(x, leaves, spec);
    refresh_leaves_num(spec);
    return {std::move(leaves), std::make_unique<PyTree>(std::move(spec))};
  }

  // Unflattens `leaves` using the spec held by the TreeSpec object `o`.
  static py::object tree_unflatten(py::iterable leaves, py::object o) {
    return o.cast<PyTree*>()->tree_unflatten(leaves);
  }

  // Rebuilds a structure from `leaves`, which can be any container or
  // iterable whose begin()/end() yield Python handles. Throws
  // std::runtime_error if `leaves` has fewer elements than the spec has
  // leaves; extra trailing elements are ignored.
  template <typename T>
  py::object tree_unflatten(T leaves) const {
    auto leaves_it = leaves.begin();
    const auto leaves_end = leaves.end();
    return unflatten_internal(spec_, leaves_it, leaves_end);
  }

  bool operator==(const PyTree& rhs) {
    return spec_ == rhs.spec_;
  }

  size_t leaves_num() const {
    return refresh_leaves_num(spec_);
  }
};
inline std::pair<std::vector<py::object>, std::unique_ptr<PyTree>> tree_flatten(
py::handle x) {
return PyTree::tree_flatten(x);
}
inline py::object tree_unflatten(py::iterable leaves, py::object o) {
return PyTree::tree_unflatten(leaves, o);
}
// Module-level `tree_map(fn, tree)`: calls `fn` on every leaf of `x` (in
// flatten order) and returns a structure shaped like `x` holding the results.
static py::object tree_map(py::function& fn, py::handle x) {
  auto flattened = tree_flatten(x);
  const auto& leaves = flattened.first;
  const auto& pytree = flattened.second;

  // Store owning py::object results. fn() returns a new reference; a bare
  // py::handle would not keep it alive once the temporary py::object is
  // destroyed, leaving a dangling pointer for any freshly created result.
  std::vector<py::object> mapped;
  mapped.reserve(leaves.size());
  for (const py::object& leaf : leaves) {
    mapped.push_back(fn(leaf));
  }
  return pytree->tree_unflatten(std::move(mapped));
}
// Module-level `from_str(spec)`: parses a serialized spec string and returns it
// wrapped in a heap-allocated PyTree (owned by Python once returned).
static std::unique_ptr<PyTree> py_from_str(std::string spec) {
  auto parsed = from_str<PyAux>(spec);
  return std::make_unique<PyTree>(std::move(parsed));
}
// Module-level `broadcast_to_and_flatten(tree, treespec)`.
//
// Broadcasts the structure of `x` onto the structure described by the TreeSpec
// `py_tree_spec` and returns the resulting flat list of leaves:
// - where `x` has a leaf, that leaf is repeated once per leaf of the
//   corresponding treespec subtree;
// - where `x` has a container, it must match the treespec node's kind and
//   child count, and for dicts the keys at each position.
// Returns None when the structures are incompatible.
//
// The traversal uses an explicit stack. Children are pushed last-to-first so
// they are popped, and their leaves emitted, in first-to-last order.
static py::object broadcast_to_and_flatten(
    py::object x,
    py::object py_tree_spec) {
  auto flattened = tree_flatten(x);
  const auto& x_leaves = flattened.first;
  const PyTreeSpec& x_root = flattened.second->spec();
  const PyTree* target = py_tree_spec.cast<PyTree*>();

  // A pending (treespec node, x node) pair. `leaf_base` is the index in
  // x_leaves of the first leaf under the x node.
  struct Pending {
    const PyTreeSpec* target_node;
    const PyTreeSpec* x_node;
    const size_t leaf_base;
  };

  py::list result;
  std::stack<Pending> work;
  work.push({&target->spec(), &x_root, size_t{0}});

  while (!work.empty()) {
    const Pending cur = work.top();
    work.pop();

    if (cur.x_node->isLeaf()) {
      // Repeat this single x leaf for every leaf of the target subtree.
      const auto repeat = cur.target_node->leaves_num();
      for (size_t r = 0; r < repeat; ++r) {
        result.append(x_leaves[cur.leaf_base]);
      }
      continue;
    }

    const auto node_kind = cur.target_node->kind();
    const size_t arity = cur.target_node->size();
    if (node_kind != cur.x_node->kind() || arity != cur.x_node->size()) {
      return py::none();
    }

    // Peel each child's leaf count off the end of this node's leaf range to
    // find where that child's leaves start.
    size_t child_base = cur.leaf_base + cur.x_node->leaves_num();
    const bool is_dict = (node_kind == Kind::Dict);
    for (size_t i = arity; i-- > 0;) {
      const PyTreeSpec& x_child = (*cur.x_node)[i];
      if (is_dict && cur.target_node->key(i) != cur.x_node->key(i)) {
        return py::none();
      }
      child_base -= x_child.leaves_num();
      work.push({&(*cur.target_node)[i], &x_child, child_base});
    }
  }
  return std::move(result);
}
} // namespace
// Python module `pybindings`: module-level pytree functions plus the
// `TreeSpec` class wrapping PyTree.
PYBIND11_MODULE(pybindings, m) {
  m.def("tree_flatten", &tree_flatten, py::arg("tree"));
  m.def("tree_unflatten", &tree_unflatten, py::arg("leaves"), py::arg("tree"));
  m.def("tree_map", &tree_map);
  m.def("from_str", &py_from_str);
  m.def("broadcast_to_and_flatten", &broadcast_to_and_flatten);
  m.def("register_custom", &PyTypeRegistry::register_custom_type);

  py::class_<PyTree>(m, "TreeSpec")
      // PyTree::py_from_str is a static function with no `self` parameter.
      // Binding it with def_static lets both TreeSpec.from_str(s) and
      // spec.from_str(s) work. With plain def, the instance form passes `self`
      // as an extra argument and fails overload resolution.
      .def_static("from_str", &PyTree::py_from_str)
      .def(
          "tree_unflatten",
          static_cast<py::object (PyTree::*)(py::iterable leaves) const>(
              &PyTree::tree_unflatten))
      .def("__repr__", &PyTree::py_to_str)
      .def("__eq__", &PyTree::operator==)
      .def("to_str", &PyTree::py_to_str)
      .def("num_leaves", &PyTree::leaves_num);
}
} // namespace pytree
} // namespace extension
} // namespace executorch