Skip to content

Add tree walking functions #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0"
julia = "1.6"

[extras]
GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "Mmap", "Tensors", "Test"]
test = ["GeometryBasics", "LinearAlgebra", "Mmap", "StableRNGs", "Tensors", "Test"]
5 changes: 4 additions & 1 deletion src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ include("inrange.jl")
include("hyperspheres.jl")
include("hyperrectangles.jl")
include("utilities.jl")
include("tree_ops.jl")
export root, treeindex, eachtreeindex, leafpoints, leaf_points_indices, region, isleaf, isroot, skip_regions, children, parent, nextsibling, prevsibling, points

include("brute_tree.jl")
include("kd_tree.jl")
include("ball_tree.jl")
include("tree_ops.jl")


for dim in (2, 3)
tree = KDTree(rand(dim, 10))
Expand Down
42 changes: 29 additions & 13 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,24 @@ function _knn(tree::BallTree,
return
end

# Bounding region of a whole BallTree: the first stored hypersphere, or the
# value produced by `_infinite_hypersphere` when no hyperspheres are stored.
@inline function region(T::BallTree)
    spheres = T.hyper_spheres
    isempty(spheres) && return _infinite_hypersphere(eltype(spheres))
    return spheres[1]
end
# Return the stored hyperspheres of the left and right children of node `index`.
# The current-region argument is only used for dispatch; its value is ignored.
@inline function _split_regions(tree::BallTree, ::HyperSphere, index::Int)
    spheres = tree.hyper_spheres
    return spheres[getleft(index)], spheres[getright(index)]
end
# Return the stored hypersphere of the parent of node `index`.
# The current-region argument is only used for dispatch; its value is ignored.
@inline function _parent_region(tree::BallTree, ::HyperSphere, index::Int)
    return tree.hyper_spheres[getparent(index)]
end

function knn_kernel!(tree::BallTree{V},
index::Int,
Expand Down Expand Up @@ -179,20 +197,17 @@ function _inrange(tree::BallTree{V},
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
return inrange_kernel!(tree, root(tree), point, ball, idx_in_ball) # Call the recursive range finders
end

function inrange_kernel!(tree::BallTree,
index::Int,
function inrange_kernel!(tree::BallTree,
node::NNTreeNode,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Union{Nothing, Vector{<:Integer}})

if index > length(tree.hyper_spheres)
return 0
end

sphere = tree.hyper_spheres[index]
sphere = region(node)
# tree = NearestNeighbors.tree(node) # give fully specified function name to avoid

# If the query ball and the bounding sphere for the current sub tree
# do not intersect we can disregard the whole subtree
Expand All @@ -201,20 +216,21 @@ function inrange_kernel!(tree::BallTree,
end

# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
return add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true)
if isleaf(tree, node)
return add_points_inrange!(idx_in_ball, tree, treeindex(node), point, query_ball.r, true)
end

count = 0

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses(tree.metric, sphere, query_ball)
count += addall(tree, index, idx_in_ball)
count += addall(tree, node, idx_in_ball)
else
# Recursively call the left and right sub tree.
count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
left, right = children(tree, node)
count += inrange_kernel!(tree, left, point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, right, point, query_ball, idx_in_ball)
end
return count
end
7 changes: 7 additions & 0 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ function knn_kernel!(tree::BruteTree{V},
end
end

# BruteTree specializations of the tree-walking functions: every node is
# reported as a leaf that covers all of `tree.data`.
isleaf(::BruteTree, ::NNTreeNode) = true
leafpoints(tree::BruteTree, ::NNTreeNode) = tree.data
leaf_points_indices(tree::BruteTree, ::NNTreeNode) = eachindex(tree.data)
eachtreeindex(::BruteTree) = 1:0  # empty range
region(tree::BruteTree) = compute_bbox(tree.data)

function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number,
Expand Down
7 changes: 7 additions & 0 deletions src/hyperspheres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ end
HyperSphere(center::SVector{N,T1}, r) where {N, T1} = HyperSphere(center, convert(T1, r))
HyperSphere(center::AbstractVector, r) = HyperSphere(SVector{length(center)}(center), r)

# Build a `HyperSphere{N,T}` whose center has all `N` coordinates equal to
# `zero(T)` and whose radius is `Inf` converted to `T`.
function _infinite_hypersphere(::Type{HyperSphere{N,T}}) where {N, T}
    origin = ntuple(_ -> zero(T), Val(N))
    radius = convert(T, Inf)
    return HyperSphere{N,T}(origin, radius)
end

@inline function intersects(m::Metric,
s1::HyperSphere{N},
s2::HyperSphere{N}) where {N}
Expand Down
103 changes: 97 additions & 6 deletions src/kd_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ struct KDTree{V <: AbstractVector, M <: MinkowskiMetric, T, TH} <: NNTree{V,M}
metric::M
split_vals::Vector{T}
split_dims::Vector{UInt16}
split_minmax::Vector{Tuple{T,T}}
tree_data::TreeData
reordered::Bool
end
Expand All @@ -30,6 +31,7 @@ function KDTree(data::AbstractVector{V},
indices = collect(1:n_p)
split_vals = Vector{eltype(V)}(undef, tree_data.n_internal_nodes)
split_dims = Vector{UInt16}(undef, tree_data.n_internal_nodes)
split_minmax = Vector{Tuple{eltype(V),eltype(V)}}(undef, tree_data.n_internal_nodes)

if reorder
indices_reordered = Vector{Int}(undef, n_p)
Expand All @@ -56,7 +58,7 @@ function KDTree(data::AbstractVector{V},
hyper_rec = compute_bbox(data)

# Call the recursive KDTree builder
build_KDTree(1, data, data_reordered, hyper_rec, split_vals, split_dims, indices, indices_reordered,
build_KDTree(1, data, data_reordered, hyper_rec, split_vals, split_dims, split_minmax, indices, indices_reordered,
1:length(data), tree_data, reorder)
if reorder
data = data_reordered
Expand All @@ -71,7 +73,7 @@ function KDTree(data::AbstractVector{V},
end
end

KDTree(storedata ? data : similar(data, 0), hyper_rec, indices, metric, split_vals, split_dims, tree_data, reorder)
KDTree(storedata ? data : similar(data, 0), hyper_rec, indices, metric, split_vals, split_dims, split_minmax, tree_data, reorder)
end

function KDTree(data::AbstractVecOrMat{T},
Expand All @@ -97,6 +99,7 @@ function build_KDTree(index::Int,
hyper_rec::HyperRectangle,
split_vals::Vector{T},
split_dims::Vector{UInt16},
split_minmax::Vector{Tuple{T,T}},
indices::Vector{Int},
indices_reordered::Vector{Int},
range,
Expand Down Expand Up @@ -129,18 +132,21 @@ function build_KDTree(index::Int,

split_vals[index] = split_val
split_dims[index] = split_dim
split_minmax[index] = (hyper_rec.mins[split_dim], hyper_rec.maxes[split_dim])

# Call the left sub tree with an updated hyper rectangle
new_maxes = @inbounds setindex(hyper_rec.maxes, split_val, split_dim)
hyper_rec_left = HyperRectangle(hyper_rec.mins, new_maxes)
build_KDTree(getleft(index), data, data_reordered, hyper_rec_left, split_vals, split_dims,
indices, indices_reordered, first(range):mid_idx - 1, tree_data, reorder)
split_minmax, indices, indices_reordered,
first(range):mid_idx - 1, tree_data, reorder)

# Call the right sub tree with an updated hyper rectangle
new_mins = @inbounds setindex(hyper_rec.mins, split_val, split_dim)
hyper_rec_right = HyperRectangle(new_mins, hyper_rec.maxes)
build_KDTree(getright(index), data, data_reordered, hyper_rec_right, split_vals, split_dims,
indices, indices_reordered, mid_idx:last(range), tree_data, reorder)
split_minmax, indices, indices_reordered, mid_idx:last(range),
tree_data, reorder)
end


Expand Down Expand Up @@ -204,17 +210,48 @@ function knn_kernel!(tree::KDTree{V},
return
end

# Bounding region of a whole KDTree: the hyper rectangle stored on the tree.
@inline region(T::KDTree) = T.hyper_rec

# Split region `R` of node `index` at that node's split value along its split
# dimension. Returns `(left, right)`: the left rectangle keeps `R.mins` and has
# its upper bound replaced by the split value; the right rectangle keeps
# `R.maxes` and has its lower bound replaced by the split value.
@inline function _split_regions(T::KDTree, R::HyperRectangle, index::Int)
    val = T.split_vals[index]
    dim = T.split_dims[index]

    left_maxes = @inbounds(setindex(R.maxes, val, dim))
    right_mins = @inbounds(setindex(R.mins, val, dim))
    return HyperRectangle(R.mins, left_maxes), HyperRectangle(right_mins, R.maxes)
end

# Rebuild the parent's hyper rectangle from region `R` of child node `index`.
# The parent's extent along its split dimension is read from `T.split_minmax`:
# a left child gets its upper bound restored, a right child its lower bound.
@inline function _parent_region(T::KDTree, R::HyperRectangle, index::Int)
    up = getparent(index)
    dim = T.split_dims[up]
    lo, hi = T.split_minmax[up]

    if getleft(up) == index
        return HyperRectangle(R.mins, @inbounds(setindex(R.maxes, hi, dim)))
    end
    return HyperRectangle(@inbounds(setindex(R.mins, lo, dim)), R.maxes)
end

function _inrange(tree::KDTree,
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}} = Int[])
init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point)
return inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(init_min)), idx_in_ball,
tree.hyper_rec, init_min)
tree.hyper_rec, init_min)
end

# Explicitly check the distance between leaf node and point while traversing
function inrange_kernel!(tree::KDTree,
function inrange_kernel!(tree::KDTree,
index::Int,
point::AbstractVector,
r::Number,
Expand Down Expand Up @@ -270,3 +307,57 @@ function inrange_kernel!(tree::KDTree,
count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min)
return count
end


# Explicitly check the distance between leaf node and point while traversing
# Node-based range search kernel for a KDTree.
#
# Arguments:
#   tree        - the KDTree being searched
#   node        - the current tree node (start with the root node)
#   point       - the query point
#   r           - search radius, compared directly against `min_dist`
#   idx_in_ball - passed through to `add_points_inrange!` at leaf nodes
#   min_dist    - lower bound on the distance from `point` to the region of
#                 `node`, in the same form as `r`
#
# Returns the sum of the values returned by `add_points_inrange!` over the
# leaves that were visited.
#
# NOTE(review): `tree` is now an explicit first argument. The previous method
# took only `node` but read an unbound `tree` and already recursed with this
# six-argument form, so it could not run past the first distance check.
function inrange_kernel!(tree::KDTree,
                         node::NNTreeNode,
                         point::AbstractVector,
                         r::Number,
                         idx_in_ball::Union{Nothing, Vector{<:Integer}},
                         min_dist)
    # Point is outside hyper rectangle, skip the whole sub tree
    if min_dist > r
        return 0
    end

    # At a leaf node. Go through all points in node and add those in range
    if isleaf(tree, node)
        return add_points_inrange!(idx_in_ball, tree, treeindex(node), point, r, false)
    end

    left, right = children(tree, node)
    M = tree.metric
    index = treeindex(node)

    split_val = tree.split_vals[index]
    split_dim = tree.split_dims[index]
    # Extent of this node's region along the split dimension, recorded when the
    # tree was built; this replaces carrying a hyper rectangle down the recursion.
    lo, hi = tree.split_minmax[index]
    p_dim = point[split_dim]
    split_diff = p_dim - split_val

    count = 0

    if split_diff > 0 # Point is to the right of the split value
        close = right
        far = left
        ddiff = max(zero(p_dim - hi), p_dim - hi)
    else # Point is to the left of the split value
        close = left
        far = right
        ddiff = max(zero(lo - p_dim), lo - p_dim)
    end
    # Call closer sub tree
    count += inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist)

    # TODO: We could potentially also keep track of the max distance
    # between the point and the hyper rectangle and add the whole sub tree
    # in case of the max distance being <= r similarly to the BallTree inrange method.
    # It would be interesting to benchmark this on some different data sets.

    # Call further sub tree with the new min distance
    split_diff_pow = eval_pow(M, split_diff)
    ddiff_pow = eval_pow(M, ddiff)
    diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim)
    new_min = eval_reduce(M, min_dist, diff_tot)
    count += inrange_kernel!(tree, far, point, r, idx_in_ball, new_min)
    return count
end
Loading