Skip to content

Commit 4a05747

Browse files
committed
Initial k-d-tree implementation
1 parent f5cc0eb commit 4a05747

File tree

3 files changed

+253
-0
lines changed

3 files changed

+253
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/build/

src/kdtree.hpp

+202
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
#include <array>
2+
#include <string>
3+
#include <vector>
4+
#include <cmath>
5+
#include <numeric>
6+
#include <algorithm>
7+
#include <tuple>
8+
#include <queue>
9+
10+
namespace kdtree {
11+
12+
using std::vector;
13+
using std::array;
14+
using std::tuple;
15+
using std::get;
16+
using std::priority_queue;
17+
18+
using Real = double;
19+
using Size = array<Real, 1>::size_type;
20+
21+
template<Size DIMS>
22+
using Point = array<Real, DIMS>; // 160 byte
23+
24+
struct Division {
25+
int dim;
26+
Real p;
27+
};
28+
29+
struct KdTree {
30+
KdTree(int size, int depth)
31+
: depth(depth),
32+
//divisions(2 * (1 << depth) - 1) { // {sum_{i=0}^{depth} i}
33+
divisions((1 << depth) - 1) {
34+
elems.reserve(size);
35+
for (auto i = 0; i < size; ++i) {
36+
elems.push_back(i);
37+
}
38+
}
39+
int depth;
40+
vector<Division> divisions;
41+
vector<int> elems;
42+
};
43+
44+
using ElemIter = vector<int>::iterator;
45+
46+
Real square(Real v) { return v * v; }
47+
48+
template<Size DIMS>
49+
void buildImpl(ElemIter begin, Size size, ElemIter lastElem, vector<Division> &divs, int mydiv,
50+
const vector<Point<DIMS>> &points, int depth, int maxDepth) {
51+
auto end = std::min(begin + size, lastElem);
52+
if (maxDepth <= depth) {
53+
return;
54+
}
55+
int currentDim = 0;
56+
Real currentVariance = 0;
57+
for (int d = 0; d < DIMS; d++) {
58+
Real average = std::accumulate(begin, end, 0, [&points, d](Real sum, int i){ return sum + points[i][d]; }) / size;
59+
Real variance = std::accumulate(begin, end, 0, [&points, d, average](Real sum, int i){ return sum + square(points[i][d] - average); }) / size;
60+
if (variance > currentVariance) {
61+
currentDim = d;
62+
currentVariance = variance;
63+
}
64+
}
65+
auto mid = std::min(begin + size / 2, lastElem - 1);
66+
std::nth_element(begin, mid, end,
67+
[&points, currentDim](int a, int b){ return points[a][currentDim] < points[b][currentDim]; });
68+
divs[mydiv] = Division{currentDim, points[*mid][currentDim]};
69+
buildImpl(begin, size / 2, lastElem, divs, 2 * mydiv + 1, points, depth + 1, maxDepth);
70+
buildImpl(mid, size / 2, lastElem, divs, 2 * mydiv + 2, points, depth + 1, maxDepth);
71+
}
72+
73+
// round up
74+
int log2ceil(int n) {
75+
n -= 1;
76+
int i = 0;
77+
for (; n != 0; i++) {
78+
n >>= 1;
79+
}
80+
return i;
81+
}
82+
83+
// round up
84+
int log2floor(int n) {
85+
int i = 0;
86+
for (; n != 0; i++) {
87+
n >>= 1;
88+
}
89+
return i - 1;
90+
}
91+
92+
template<Size DIMS>
93+
KdTree buildKdTree(const vector<Point<DIMS>> &points) {
94+
auto depth = static_cast<int>(log2ceil(points.size()));
95+
KdTree tree{static_cast<int>(points.size()), depth};
96+
buildImpl(tree.elems.begin(), 1 << depth, tree.elems.end(), tree.divisions, 0, points, 0, tree.depth);
97+
return tree;
98+
}
99+
100+
template<Size DIMS>
101+
Real distSquared(Point<DIMS> p1, Point<DIMS> p2) {
102+
Real d = 0;
103+
for (auto i = 0; i < DIMS; ++i) {
104+
d += square(p1[i] - p2[i]);
105+
}
106+
return d;
107+
}
108+
109+
template<Size DIMS, typename Queue>
110+
void searchNNDown(Size divI, ElemIter begin, Size size,
111+
Size largestSizeToMoveUpTo,
112+
Real minDistInTree, array<Real, DIMS> &minDistInTreePerDim,
113+
Queue &nearest,
114+
const Point<DIMS> p, Size secoundLastLevel, ElemIter totalEnd, const KdTree &tree, const vector<Point<DIMS>> &points, int k) {
115+
116+
117+
while (divI < secoundLastLevel) {
118+
auto div = tree.divisions[divI];
119+
auto left = p[div.dim] < div.p;
120+
divI = 2 * divI + (left ? 1 : 2);
121+
size /= 2;
122+
begin = left ? begin : begin + size;
123+
}
124+
for (auto i = begin; i < std::min(begin + size, totalEnd); i++) {
125+
auto dist = distSquared(points[*i], p);
126+
if (nearest.size() < k) {
127+
nearest.push({dist, *i});
128+
} else if (dist < get<Real>(nearest.top())) {
129+
nearest.pop();
130+
nearest.push({dist, *i});
131+
}
132+
}
133+
searchNNUp(divI, begin, size,
134+
largestSizeToMoveUpTo,
135+
minDistInTree, minDistInTreePerDim,
136+
nearest,
137+
p, secoundLastLevel, totalEnd, tree, points, k);
138+
}
139+
140+
template<Size DIMS, typename Queue>
141+
void searchNNUp(Size divI, ElemIter begin, Size size,
142+
Size largestSizeToMoveUpTo,
143+
Real minDistInTree, array<Real, DIMS> minDistInTreePerDim /* optimize copy? */,
144+
Queue &nearest,
145+
const Point<DIMS> p, Size secoundLastLevel, ElemIter totalEnd, const KdTree &tree, const vector<Point<DIMS>> &points, int k) {
146+
147+
while (size < largestSizeToMoveUpTo) {
148+
auto div = tree.divisions[divI];
149+
auto isRightChild = divI % 2 == 0;
150+
if (nearest.size() < k || minDistInTree < get<Real>(nearest.top())) {
151+
auto distInTreeForDim = square(p[div.dim] - div.p);
152+
minDistInTree += distInTreeForDim - minDistInTreePerDim[div.dim];
153+
minDistInTreePerDim[div.dim] = distInTreeForDim;
154+
auto beginOther = begin + (isRightChild ? -size : +size);
155+
auto sizeOther = size;
156+
auto divOther = divI + (isRightChild ? -1 : +1);
157+
searchNNDown(divOther, beginOther, sizeOther,
158+
size,
159+
minDistInTree, minDistInTreePerDim,
160+
nearest,
161+
p, secoundLastLevel, totalEnd, tree, points, k);
162+
}
163+
auto beginUp = begin + (isRightChild ? -size : 0);
164+
auto sizeUp = size * 2;
165+
auto divUp = (divI / 2) - 1;
166+
begin = beginUp;
167+
size = sizeUp;
168+
divI = divUp;
169+
}
170+
}
171+
172+
template<Size DIMS>
173+
void knn(const vector<Point<DIMS>> &points, int k) {
174+
//priority_queue<tuple<Real, Size>, vector<tuple<Real, Size>, std::greater<int>> nearest{};
175+
auto tree = buildKdTree(points);
176+
std::cout << "tree build done" << '\n';
177+
auto secoundLastLevel = tree.divisions.size() / 2; // always floor b/c size is odd
178+
auto initSize = 1 << log2ceil(tree.elems.size());
179+
for (const auto p : points) {
180+
vector<tuple<Real, Size>> queueContainer{};
181+
queueContainer.reserve(k);
182+
auto compare = [](const tuple<Real, Size> &e1, const tuple<Real, Size> &e2){ return get<Real>(e1) > get<Real>(e2); };
183+
priority_queue<tuple<Real, Size>, decltype(queueContainer), decltype(compare)> nearest{compare, queueContainer};
184+
Size divI = 0; // index of division
185+
array<Real, DIMS> minDistInTreePerDim{}; // {} to zero initialize
186+
Real minDistInTree = 0;
187+
// search Down
188+
searchNNDown(divI, tree.elems.begin(), initSize,
189+
initSize,
190+
minDistInTree, minDistInTreePerDim,
191+
nearest,
192+
p, secoundLastLevel, tree.elems.end(), tree, points, k);
193+
std::cout << "elements for this point: " << nearest.size() << '\n';
194+
while (nearest.size()) {
195+
std::cout << ", " << get<Size>(nearest.top());
196+
nearest.pop();
197+
}
198+
std::cout << '\n';
199+
}
200+
}
201+
202+
}

src/main.cpp

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include <iostream>
2+
3+
#include "kdtree.hpp"
4+
5+
constexpr std::array<double, 1>::size_type dims = 5;
6+
7+
void gen_points(std::vector<std::array<double, dims>> &ps,
8+
std::array<double, dims> &p, int d, int n) {
9+
if (d == dims) {
10+
ps.push_back(p);
11+
} else {
12+
for (int i = 0; i < n; ++i) {
13+
p[d] = static_cast<double>(i);
14+
gen_points(ps, p, d + 1, n);
15+
}
16+
}
17+
}
18+
19+
void f() {
20+
std::vector<std::array<double, dims>> points;
21+
std::array<double, dims> dummy;
22+
gen_points(points, dummy, 0, 5);
23+
std::cout << "input generated\n";
24+
kdtree::knn(points, 5);
25+
}
26+
27+
#define assert(a) std::cout << (a) << '\n'
28+
29+
void testLog2() {
30+
assert(kdtree::log2ceil(1024) == 10);
31+
assert(kdtree::log2ceil(1023) == 10);
32+
assert(kdtree::log2ceil(1025) == 11);
33+
assert(kdtree::log2ceil(14) == 4);
34+
assert(kdtree::log2ceil(15) == 4);
35+
assert(kdtree::log2ceil(16) == 4);
36+
assert(kdtree::log2ceil(17) == 5);
37+
38+
assert(kdtree::log2floor(1024) == 10);
39+
assert(kdtree::log2floor(1023) == 9);
40+
assert(kdtree::log2floor(1025) == 10);
41+
assert(kdtree::log2floor(14) == 3);
42+
assert(kdtree::log2floor(15) == 3);
43+
assert(kdtree::log2floor(16) == 4);
44+
assert(kdtree::log2floor(17) == 4);
45+
}
46+
47+
int main() {
48+
f();
49+
//testLog2();
50+
}

0 commit comments

Comments
 (0)