forked from Benjamin-Dobell/ge_tts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path Graph.ttslua
101 lines (81 loc) · 4.21 KB
/
Graph.ttslua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
local TableUtils = require('ge_tts.TableUtils')
-- Module table collecting the public graph traversal functions (traverse, breadthTraverse, find).
local Graph = {}
--- Creates a coroutine that performs a breadth first visit of the hierarchy rooted at `node`.
---
--- Each resume advances the traversal by one depth level:
--- - the first resume visits `node` itself;
--- - the next resume visits its children;
--- - the one after that visits its grandchildren, and so on.
--- The coroutine yields (with no values) between levels and finishes with no values once the whole hierarchy has been
--- visited.
---
--- If `visitCallback` returns any value (including an explicit nil), traversal halts and the coroutine finishes
--- returning exactly those values. Errors raised while visiting any descendant are re-raised. The caller's
--- `coroutine.resume` therefore reports them as a failure rather than handing them back as a traversal result.
---@param node any
---@param getChildren fun(node: any): nil | any[] @A callback that when passed a node, must return a table of the node's children to be traversed, or nil.
---@param visitCallback fun(node: any) @If callback returns *any* value (including nil), then traversal is halted and the value is returned.
---@return thread
local function breadthVisitCoroutine(node, getChildren, visitCallback)
    return coroutine.create(function()
        ---@type std__Packed<any>
        local result = table.pack(visitCallback(node))

        -- result.n (not #result) so that an explicit `return nil` from the callback still counts as a returned value.
        if result.n > 0 then
            return table.unpack(result, 1, result.n)
        end

        local children = getChildren(node)

        if children then
            ---@type thread[]
            local visitDescendantCoroutines = {}

            -- Iterate children the same way Graph.traverse does (ipairs over the sequence part).
            for i, child in ipairs(--[[---@not nil]] children) do
                visitDescendantCoroutines[i] = breadthVisitCoroutine(child, getChildren, visitCallback)
            end

            repeat
                -- Hand control back so every node at the current depth is visited before any deeper node.
                coroutine.yield()

                local stopped = true

                for _, visitDescendant in ipairs(visitDescendantCoroutines) do
                    if coroutine.status(visitDescendant) == 'suspended' then
                        ---@type std__Packed<any>
                        local resumed = table.pack(coroutine.resume(visitDescendant))

                        if not resumed[1] then
                            -- Propagate the descendant's error instead of mistaking the message for a result.
                            error(resumed[2], 0)
                        end

                        -- Yields and normal completion produce only the status flag; any further values are a
                        -- halting result from visitCallback (possibly a single nil, hence .n rather than #).
                        if resumed.n > 1 then
                            return table.unpack(resumed, 2, resumed.n)
                        end

                        stopped = false
                    end
                end
            until stopped
        end
    end)
end
--- Performs preorder (depth first) traversal over a node hierarchy starting at `node`. If `visitCallback` returns a
--- value, traversal stops and the value is returned.
---@param node any
---@param getChildren fun(node: any): nil | any[] @A callback that when passed a node, must return a table of the node's children to be traversed, or nil.
---@param visitCallback fun(node: any) @If callback returns *any* value (including nil), then traversal is halted and the value is returned.
---@return any... @The return value(s) of callback, or no return value if the entire tree traverses without callback returning a value.
function Graph.traverse(node, getChildren, visitCallback)
    ---@type std__Packed<any>
    local result = table.pack(visitCallback(node))

    -- The callback's values are packed directly (there is no leading status flag as with coroutine.resume), so any
    -- value at all halts traversal. result.n is used so that an explicit nil is counted too.
    if result.n > 0 then
        return table.unpack(result, 1, result.n)
    end

    local children = getChildren(node)

    if children then
        for _, child in ipairs(--[[---@not nil]] children) do
            result = table.pack(Graph.traverse(child, getChildren, visitCallback))

            -- A recursive call returns nothing unless a descendant's callback halted traversal.
            if result.n > 0 then
                return table.unpack(result, 1, result.n)
            end
        end
    end
end
--- Performs breadth first traversal over a node hierarchy starting at `root`. If `visitCallback` returns a value,
--- traversal stops and the value is returned. Errors raised by `getChildren` or `visitCallback` are re-raised to the
--- caller.
---@param root table
---@param getChildren fun(node: table): nil | any[] @A callback that when passed a node, must return a table of the node's children to be traversed, or nil.
---@param visitCallback fun(node: table) @If callback returns *any* value (including nil), then traversal is halted and the value is returned.
---@return any... @The return value(s) of callback, or no return value if the entire tree traverses without callback returning a value.
function Graph.breadthTraverse(root, getChildren, visitCallback)
    local breadthVisit = breadthVisitCoroutine(root, getChildren, visitCallback)

    -- Each resume visits one more depth level; keep going until a value is returned or the coroutine finishes.
    repeat
        ---@type std__Packed<any>
        local result = table.pack(coroutine.resume(breadthVisit))

        if not result[1] then
            -- resume reports failures as (false, err); re-raise rather than returning err as a traversal result.
            error(result[2], 0)
        end

        -- Anything beyond the status flag is the callback's halting result (possibly a single nil, hence .n).
        if result.n > 1 then
            return --[[---@not nil]] table.unpack(result, 2, result.n)
        end
    until coroutine.status(breadthVisit) ~= 'suspended'
end
--- Performs breadth first search over a node hierarchy starting at `root`. Returns the first node for which
--- `visitCallback` returns a truthy value (anything other than nil or false).
---@param root table
---@param getChildren fun(node: table): nil | any[] @A callback that when passed a node, must return a table of the node's children to be traversed, or nil. The table may be empty.
---@param visitCallback fun(node: table): any @Condition callback; a truthy result selects the node.
---@return any @The first matching node in breadth first order, or no return value if no node matches.
function Graph.find(root, getChildren, visitCallback)
    return Graph.breadthTraverse(root, getChildren, function(node)
        if visitCallback(node) then
            return node
        end
        -- Returning nothing (rather than nil) lets breadthTraverse continue to the next node.
    end)
end
-- Expose the module table to `require` callers.
return Graph