Skip to content

Commit 5fcd36b

Browse files
committed
Add Lua+tree-sitter stack graph builder
This is the spackle that parses a source file using tree-sitter, and calls a Lua function with it and an empty stack graph. The Lua function can do whatever it wants to walk the parse tree and add nodes and edges to the graph.
1 parent 84d4bec commit 5fcd36b

File tree

7 files changed

+225
-20
lines changed

7 files changed

+225
-20
lines changed

stack-graphs/src/lua.rs

+32-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
//!
3232
//! let mut graph = StackGraph::new();
3333
//! lua.scope(|scope| {
34-
//! let graph = scope.create_userdata_ref_mut(&mut graph);
34+
//! let graph = graph.lua_ref_mut(&scope)?;
3535
//! process_graph.call(graph)
3636
//! })?;
3737
//! assert_eq!(graph.iter_nodes().count(), 3);
@@ -377,6 +377,8 @@ use std::num::NonZeroU32;
377377
use controlled_option::ControlledOption;
378378
use lsp_positions::Span;
379379
use mlua::AnyUserData;
380+
use mlua::Lua;
381+
use mlua::Scope;
380382
use mlua::UserData;
381383
use mlua::UserDataMethods;
382384

@@ -385,6 +387,35 @@ use crate::graph::File;
385387
use crate::graph::Node;
386388
use crate::graph::StackGraph;
387389

390+
impl StackGraph {
391+
// Returns a Lua wrapper for this stack graph. Takes ownership of the stack graph. If you
392+
// want to access the stack graph after your Lua code is done with it, use [`lua_ref_mut`]
393+
// instead.
394+
pub fn lua_value<'lua>(self, lua: &'lua Lua) -> Result<AnyUserData<'lua>, mlua::Error> {
395+
lua.create_userdata(self)
396+
}
397+
398+
// Returns a scoped Lua wrapper for this stack graph.
399+
pub fn lua_ref_mut<'lua, 'scope>(
400+
&'scope mut self,
401+
scope: &Scope<'lua, 'scope>,
402+
) -> Result<AnyUserData<'lua>, mlua::Error> {
403+
scope.create_userdata_ref_mut(self)
404+
}
405+
406+
// Returns a scoped Lua wrapper for a file in this stack graph.
407+
pub fn file_lua_ref_mut<'lua, 'scope>(
408+
&'scope mut self,
409+
file: Handle<File>,
410+
scope: &Scope<'lua, 'scope>,
411+
) -> Result<AnyUserData<'lua>, mlua::Error> {
412+
let graph_ud = self.lua_ref_mut(scope)?;
413+
let file_ud = scope.create_userdata(file)?;
414+
file_ud.set_user_value(graph_ud)?;
415+
Ok(file_ud)
416+
}
417+
}
418+
388419
impl UserData for StackGraph {
389420
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
390421
methods.add_function("file", |l, (graph_ud, name): (AnyUserData, String)| {

stack-graphs/tests/it/lua.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl CheckLua for mlua::Lua {
9090

9191
fn check(&self, graph: &mut StackGraph, chunk: &str) -> Result<(), mlua::Error> {
9292
self.scope(|scope| {
93-
let graph = scope.create_userdata_ref_mut(graph);
93+
let graph = graph.lua_ref_mut(&scope)?;
9494
self.load(chunk).set_name("test chunk").call(graph)
9595
})
9696
}

tree-sitter-stack-graphs/Cargo.toml

+7
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ lsp = [
4747
"tokio",
4848
"tower-lsp",
4949
]
50+
lua = [
51+
"mlua",
52+
"mlua-tree-sitter",
53+
"stack-graphs/lua",
54+
]
5055

5156
[dependencies]
5257
anyhow = "1.0"
@@ -63,6 +68,8 @@ indoc = { version = "1.0", optional = true }
6368
itertools = "0.10"
6469
log = "0.4"
6570
lsp-positions = { version="0.4", path="../lsp-positions", features=["tree-sitter"] }
71+
mlua = { version = "0.9", optional = true }
72+
mlua-tree-sitter = { version = "0.1", git="https://github.com/dcreager/ltreesitter", branch="mlua", optional = true }
6673
once_cell = "1"
6774
pathdiff = { version = "0.2.1", optional = true }
6875
regex = "1"

tree-sitter-stack-graphs/src/lib.rs

+30-18
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ use std::time::Duration;
357357
use std::time::Instant;
358358
use thiserror::Error;
359359
use tree_sitter::Parser;
360+
use tree_sitter::Tree;
360361
use tree_sitter_graph::functions::Functions;
361362
use tree_sitter_graph::graph::Edge;
362363
use tree_sitter_graph::graph::Graph;
@@ -375,6 +376,8 @@ pub mod ci;
375376
pub mod cli;
376377
pub mod functions;
377378
pub mod loader;
379+
#[cfg(feature = "lua")]
380+
pub mod lua;
378381
pub mod test;
379382
mod util;
380383

@@ -578,6 +581,29 @@ impl StackGraphLanguage {
578581
}
579582
}
580583

584+
pub(crate) fn parse_file(
585+
language: tree_sitter::Language,
586+
source: &str,
587+
cancellation_flag: &dyn CancellationFlag,
588+
) -> Result<Tree, BuildError> {
589+
let tree = {
590+
let mut parser = Parser::new();
591+
parser.set_language(language)?;
592+
let ts_cancellation_flag = TreeSitterCancellationFlag::from(cancellation_flag);
593+
// The parser.set_cancellation_flag` is unsafe, because it does not tie the
594+
// lifetime of the parser to the lifetime of the cancellation flag in any way.
595+
// To make it more obvious that the parser does not outlive the cancellation flag,
596+
// it is put into its own block here, instead of extending to the end of the method.
597+
unsafe { parser.set_cancellation_flag(Some(ts_cancellation_flag.as_ref())) };
598+
parser.parse(source, None).ok_or(BuildError::ParseError)?
599+
};
600+
let parse_errors = ParseError::into_all(tree);
601+
if parse_errors.errors().len() > 0 {
602+
return Err(BuildError::ParseErrors(parse_errors));
603+
}
604+
Ok(parse_errors.into_tree())
605+
}
606+
581607
pub struct Builder<'a> {
582608
sgl: &'a StackGraphLanguage,
583609
stack_graph: &'a mut StackGraph,
@@ -615,24 +641,7 @@ impl<'a> Builder<'a> {
615641
globals: &'a Variables<'a>,
616642
cancellation_flag: &dyn CancellationFlag,
617643
) -> Result<(), BuildError> {
618-
let tree = {
619-
let mut parser = Parser::new();
620-
parser.set_language(self.sgl.language)?;
621-
let ts_cancellation_flag = TreeSitterCancellationFlag::from(cancellation_flag);
622-
// The parser.set_cancellation_flag` is unsafe, because it does not tie the
623-
// lifetime of the parser to the lifetime of the cancellation flag in any way.
624-
// To make it more obvious that the parser does not outlive the cancellation flag,
625-
// it is put into its own block here, instead of extending to the end of the method.
626-
unsafe { parser.set_cancellation_flag(Some(ts_cancellation_flag.as_ref())) };
627-
parser
628-
.parse(self.source, None)
629-
.ok_or(BuildError::ParseError)?
630-
};
631-
let parse_errors = ParseError::into_all(tree);
632-
if parse_errors.errors().len() > 0 {
633-
return Err(BuildError::ParseErrors(parse_errors));
634-
}
635-
let tree = parse_errors.into_tree();
644+
let tree = parse_file(self.sgl.language, self.source, cancellation_flag)?;
636645

637646
let mut globals = Variables::nested(globals);
638647
if globals.get(&ROOT_NODE_VAR.into()).is_none() {
@@ -826,6 +835,9 @@ pub enum BuildError {
826835
LanguageError(#[from] tree_sitter::LanguageError),
827836
#[error("Expected exported symbol scope in {0}, got {1}")]
828837
SymbolScopeError(String, String),
838+
#[cfg(feature = "lua")]
839+
#[error(transparent)]
840+
LuaError(#[from] mlua::Error),
829841
}
830842

831843
impl From<stack_graphs::CancellationError> for BuildError {

tree-sitter-stack-graphs/src/lua.rs

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// -*- coding: utf-8 -*-
2+
// ------------------------------------------------------------------------------------------------
3+
// Copyright © 2023, stack-graphs authors.
4+
// Licensed under either of Apache License, Version 2.0, or MIT license, at your option.
5+
// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details.
6+
// ------------------------------------------------------------------------------------------------
7+
8+
//! Construct stack graphs using a Lua script that consumes a tree-sitter parse tree
9+
10+
use std::borrow::Cow;
11+
12+
use mlua::Lua;
13+
use mlua_tree_sitter::Module;
14+
use mlua_tree_sitter::WithSource;
15+
use stack_graphs::arena::Handle;
16+
use stack_graphs::graph::File;
17+
use stack_graphs::graph::StackGraph;
18+
19+
use crate::parse_file;
20+
use crate::BuildError;
21+
use crate::CancellationFlag;
22+
23+
/// Holds information about how to construct stack graphs for a particular language.
24+
pub struct StackGraphLanguageLua {
25+
language: tree_sitter::Language,
26+
lua_source: Cow<'static, [u8]>,
27+
lua_source_name: String,
28+
}
29+
30+
impl StackGraphLanguageLua {
31+
/// Creates a new stack graph language for the given language, loading the Lua stack graph
32+
/// construction rules from a static string.
33+
pub fn from_static_str(
34+
language: tree_sitter::Language,
35+
lua_source: &'static [u8],
36+
lua_source_name: &str,
37+
) -> StackGraphLanguageLua {
38+
StackGraphLanguageLua {
39+
language,
40+
lua_source: Cow::from(lua_source),
41+
lua_source_name: lua_source_name.to_string(),
42+
}
43+
}
44+
45+
/// Creates a new stack graph language for the given language, loading the Lua stack graph
46+
/// construction rules from a string.
47+
pub fn from_str(
48+
language: tree_sitter::Language,
49+
lua_source: &[u8],
50+
lua_source_name: &str,
51+
) -> StackGraphLanguageLua {
52+
StackGraphLanguageLua {
53+
language,
54+
lua_source: Cow::from(lua_source.to_vec()),
55+
lua_source_name: lua_source_name.to_string(),
56+
}
57+
}
58+
59+
pub fn language(&self) -> tree_sitter::Language {
60+
self.language
61+
}
62+
63+
pub fn lua_source_name(&self) -> &str {
64+
&self.lua_source_name
65+
}
66+
67+
pub fn lua_source(&self) -> &Cow<'static, [u8]> {
68+
&self.lua_source
69+
}
70+
71+
/// Executes the graph construction rules for this language against a source file, creating new
72+
/// nodes and edges in `stack_graph`. Any new nodes that we create will belong to `file`.
73+
/// (The source file must be implemented in this language, otherwise you'll probably get a
74+
/// parse error.)
75+
pub fn build_stack_graph_into<'a>(
76+
&'a self,
77+
stack_graph: &'a mut StackGraph,
78+
file: Handle<File>,
79+
source: &'a str,
80+
cancellation_flag: &'a dyn CancellationFlag,
81+
) -> Result<(), BuildError> {
82+
// Create a Lua environment and load the language's stack graph rules.
83+
// TODO: Sandbox the Lua environment
84+
let mut lua = Lua::new();
85+
lua.open_ltreesitter(false)?;
86+
lua.load(self.lua_source.as_ref())
87+
.set_name(&self.lua_source_name)
88+
.exec()?;
89+
let process: mlua::Function = lua.globals().get("process")?;
90+
91+
// Parse the source using the requested grammar.
92+
let tree = parse_file(self.language, source, cancellation_flag)?;
93+
let tree = tree.with_source(source.as_bytes());
94+
95+
// Invoke the Lua `process` function with the parsed tree and the stack graph file.
96+
// TODO: Add a debug hook that checks the cancellation flag during execution
97+
lua.scope(|scope| {
98+
let file = stack_graph.file_lua_ref_mut(file, scope)?;
99+
process.call((tree, file))
100+
})?;
101+
Ok(())
102+
}
103+
}
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// -*- coding: utf-8 -*-
2+
// ------------------------------------------------------------------------------------------------
3+
// Copyright © 2023, stack-graphs authors.
4+
// Licensed under either of Apache License, Version 2.0, or MIT license, at your option.
5+
// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details.
6+
// ------------------------------------------------------------------------------------------------
7+
8+
use stack_graphs::graph::StackGraph;
9+
use tree_sitter_stack_graphs::lua::StackGraphLanguageLua;
10+
use tree_sitter_stack_graphs::NoCancellation;
11+
12+
use crate::edges::check_stack_graph_edges;
13+
use crate::nodes::check_stack_graph_nodes;
14+
15+
// This doesn't build a very _interesting_ stack graph, but it does test that the end-to-end
16+
// spackle all works correctly.
17+
#[test]
18+
fn can_build_stack_graph_from_lua() {
19+
const LUA: &[u8] = br#"
20+
function process(parsed, file)
21+
-- TODO: fill in the definiens span from the parse tree root
22+
local module = file:internal_scope_node()
23+
module:add_edge_from(file:root_node())
24+
end
25+
"#;
26+
let python = "pass";
27+
28+
let mut graph = StackGraph::new();
29+
let file = graph.get_or_create_file("test.py");
30+
let language =
31+
StackGraphLanguageLua::from_static_str(tree_sitter_python::language(), LUA, "test");
32+
language
33+
.build_stack_graph_into(&mut graph, file, python, &NoCancellation)
34+
.expect("Failed to build graph");
35+
36+
check_stack_graph_nodes(
37+
&graph,
38+
file,
39+
&[
40+
"[test.py(0) scope]", //
41+
],
42+
);
43+
check_stack_graph_edges(
44+
&graph,
45+
&[
46+
"[root] -0-> [test.py(0) scope]", //
47+
],
48+
);
49+
}

tree-sitter-stack-graphs/tests/it/main.rs

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ mod loader;
1919
mod nodes;
2020
mod test;
2121

22+
#[cfg(feature = "lua")]
23+
mod lua;
24+
2225
pub(self) fn build_stack_graph(
2326
python_source: &str,
2427
tsg_source: &str,

0 commit comments

Comments
 (0)