diff --git a/java/com/metabase/macaw/AstWalker.java b/java/com/metabase/macaw/AstWalker.java
index 4c475d5..9672d17 100644
--- a/java/com/metabase/macaw/AstWalker.java
+++ b/java/com/metabase/macaw/AstWalker.java
@@ -167,6 +167,7 @@ public class AstWalker implements SelectVisitor, FromItemVisitor, Expressio
         SelectItemVisitor, StatementVisitor, GroupByVisitor {
 
     public enum CallbackKey {
+        EVERY_NODE,
         ALIAS,
         ALL_COLUMNS,
         ALL_TABLE_COLUMNS,
@@ -277,6 +278,10 @@ public void invokeCallback(CallbackKey key, Object visitedItem) {
             //noinspection unchecked
             this.acc = (Acc) callback.invoke(acc, visitedItem, this.contextStack);
         }
+
+        if (key != EVERY_NODE) {
+            invokeCallback(EVERY_NODE, visitedItem);
+        }
     }
 
     private void pushContext(QueryScopeLabel label) {
diff --git a/java/com/metabase/macaw/SimpleParser.java b/java/com/metabase/macaw/SimpleParser.java
new file mode 100644
index 0000000..9a61f97
--- /dev/null
+++ b/java/com/metabase/macaw/SimpleParser.java
@@ -0,0 +1,371 @@
+package com.metabase.macaw;
+
+import clojure.lang.Keyword;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.Function;
+import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
+import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
+import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
+import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
+import net.sf.jsqlparser.schema.Column;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.Statement;
+import net.sf.jsqlparser.statement.select.*;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Return a simplified query representation we can work with further, if possible.
+ */
+@SuppressWarnings({
+        "rawtypes", // will let us return Persistent datastructures eventually
+        "unchecked", // lets us use raw types without casting
+        "PatternVariableCanBeUsed", "IfCanBeSwitch"} // don't force a newer JVM version
+)
+public final class SimpleParser {
+
+    /**
+     * Translates a parsed statement into a nested map describing the query.
+     *
+     * @param statement the parsed statement to translate
+     * @return the simplified representation, or {@code null} when the statement is not a
+     *         {@link Select} or uses a feature that is not supported yet
+     */
+    public static Map maybeParse(Statement statement) {
+        try {
+            if (statement instanceof Select) {
+                return maybeParse((Select) statement);
+            }
+            // This is not a query.
+            return null;
+        } catch (IllegalArgumentException e) {
+            // This query uses features that we do not yet support translating.
+            System.err.println(e.getMessage());
+            return null;
+        }
+    }
+
+    /**
+     * Translates a select statement through its plain-select form.
+     *
+     * @throws IllegalArgumentException if {@code getPlainSelect()} returns {@code null}
+     */
+    private static Map maybeParse(Select select) {
+        PlainSelect ps = select.getPlainSelect();
+        if (ps != null) {
+            return maybeParse(ps);
+        }
+        // We don't support more complex kinds of select statements yet.
+        throw new IllegalArgumentException("Unsupported query type " + select.getClass().getName());
+    }
+
+    /**
+     * Translates a plain select into a map with the keys {@code select}, {@code from},
+     * {@code where}, {@code group-by}, {@code order-by} and {@code limit}; every key except
+     * {@code select} is only present when the query has the matching clause.
+     *
+     * @throws IllegalArgumentException if the query uses a clause that is not translated
+     */
+    private static Map maybeParse(PlainSelect select) {
+        // any of these - nope out
+        if (select.getDistinct() != null ||
+                select.getFetch() != null ||
+                select.getFirst() != null ||
+                select.getForClause() != null ||
+                select.getForMode() != null ||
+                select.getForUpdateTable() != null ||
+                select.getForXmlPath() != null ||
+                select.getHaving() != null ||
+                select.getIntoTables() != null ||
+                select.getIsolation() != null ||
+                select.getKsqlWindow() != null ||
+                select.getLateralViews() != null ||
+                select.getLimitBy() != null ||
+                select.getMySqlHintStraightJoin() ||
+                select.getMySqlSqlCacheFlag() != null ||
+                select.getOffset() != null ||
+                select.getOptimizeFor() != null ||
+                select.getOracleHierarchical() != null ||
+                select.getOracleHint() != null ||
+                select.getSkip() != null ||
+                select.getTop() != null ||
+                select.getWait() != null ||
+                select.getWindowDefinitions() != null ||
+                select.getWithItemsList() != null) {
+            throw new IllegalArgumentException("Unsupported query feature(s)");
+        }
+
+        Map m = new HashMap();
+        m.put("select", select.getSelectItems().stream().map(SimpleParser::parse).toList());
+
+        if (select.getFromItem() != null) {
+            ArrayList from = new ArrayList();
+            from.add(parse(select.getFromItem()));
+            List joins = select.getJoins();
+            if (joins != null) {
+                joins.stream().map(SimpleParser::parse).forEach(from::add);
+            }
+            m.put("from", from);
+        }
+
+        Expression where = select.getWhere();
+        if (where != null) {
+            m.put("where", parseWhere(where));
+        }
+        GroupByElement gbe = select.getGroupBy();
+        if (gbe != null) {
+            m.put("group-by", parse(gbe));
+        }
+        List obe = select.getOrderByElements();
+        if (obe != null) {
+            m.put("order-by", obe.stream().map(SimpleParser::parse).toList());
+        }
+        Limit limit = select.getLimit();
+        if (limit != null) {
+            m.put("limit", parse(limit));
+        }
+        return m;
+    }
+
+    /**
+     * Translates a join into the table it adds to the {@code from} list.
+     *
+     * @throws IllegalArgumentException if the join is of a kind that is not translated, or carries a hint, window, USING columns or ON expressions
+     */
+    private static Map parse(Join join) {
+        if (join.isApply() ||
+                join.isCross() ||
+                join.isGlobal() ||
+                join.isSemi() ||
+                join.isStraight() ||
+                join.isWindowJoin() ||
+                join.getJoinHint() != null ||
+                join.getJoinWindow() != null ||
+                !join.getUsingColumns().isEmpty()) {
+            throw new IllegalArgumentException("Unsupported join expression");
+        }
+        assert(join.isSimple());
+
+        if (join.isFull() ||
+                join.isLeft() ||
+                join.isRight()) {
+            // TODO
+            throw new IllegalArgumentException("Join type not supported yet");
+        }
+        assert(join.isInnerJoin());
+
+        if (!join.getOnExpressions().isEmpty()) {
+            throw new IllegalArgumentException("Only unconditional joins supported for now");
+        }
+
+        return parse(join.getFromItem());
+    }
+
+    /**
+     * Translates a from item, which must be a plain table reference.
+     *
+     * @throws IllegalArgumentException if the item is not a {@link Table}
+     */
+    private static Map parse(FromItem fromItem) {
+        // We don't support table aliases yet - which is fine since pMBQL doesn't generate them
+        // fromItem.getAlias();
+        if (fromItem instanceof Table) {
+            return parse((Table) fromItem);
+        }
+        throw new IllegalArgumentException("Unsupported from clause");
+    }
+
+    /**
+     * Extracts the row count of a limit clause.
+     *
+     * @throws IllegalArgumentException if the clause has an offset, BY expressions, or a
+     *                                  row count that is not an integer literal
+     */
+    private static Long parse(Limit limit) {
+        Expression rc = limit.getRowCount();
+        if (limit.getOffset() != null || limit.getByExpressions() != null || !(rc instanceof LongValue)) {
+            throw new IllegalArgumentException("Unsupported limit clause");
+        }
+        return ((LongValue) rc).getValue();
+    }
+
+    /**
+     * Translates an order-by element that sorts on a plain column.
+     * The sort direction is not read, so it is absent from the result.
+     *
+     * @throws IllegalArgumentException if a null ordering is given or the expression is not a column
+     */
+    private static Map parse(OrderByElement elem) {
+        if (elem.getNullOrdering() != null) {
+            throw new IllegalArgumentException("Unsupported order by clause(s)");
+        }
+        Expression e = elem.getExpression();
+        if (e instanceof Column) {
+            return parse((Column) e);
+        }
+        throw new IllegalArgumentException("Unsupported order by clause(s)");
+    }
+
+    /**
+     * Translates a where clause into a prefix form: {@code [operator, left, right]}.
+     * Only {@code =}, {@code >} and {@code >=} comparisons are supported.
+     *
+     * @throws IllegalArgumentException for any other expression or comparison operator
+     */
+    private static List parseWhere(Expression where) {
+        // oh my lord, what a mission to convert all these, definitely some clojure metaprogramming would be nice
+        if (where instanceof ComparisonOperator) {
+            ComparisonOperator co = (ComparisonOperator) where;
+            if (co.getOldOracleJoinSyntax() > 0 || co.getOraclePriorPosition() > 0) {
+                throw new IllegalArgumentException("Unsupported where clause");
+            }
+            ArrayList form = new ArrayList();
+            // if we handle ComparisonOperator then we could get the private field "operator" and rely on that.
+            // Keyword.intern rather than Keyword.find: find returns null for a keyword nobody has created yet.
+            if (co instanceof EqualsTo) {
+                form.add(Keyword.intern("="));
+            } else if (co instanceof GreaterThan) {
+                form.add(Keyword.intern(">"));
+            } else if (co instanceof GreaterThanEquals) {
+                form.add(Keyword.intern(">="));
+            } else {
+                throw new IllegalArgumentException("Unsupported where clause");
+            }
+
+            form.add(parseComparisonExpression(co.getLeftExpression()));
+            form.add(parseComparisonExpression(co.getRightExpression()));
+            return form;
+        }
+
+        throw new IllegalArgumentException("Unsupported where clause");
+    }
+
+    /**
+     * Translates one side of a comparison: a column becomes a map, an integer literal a {@code Long}.
+     *
+     * @throws IllegalArgumentException for any other kind of expression
+     */
+    private static Object parseComparisonExpression(Expression expr) {
+        if (expr instanceof Column) {
+            return parse((Column) expr);
+        } else if (expr instanceof LongValue) {
+            return ((LongValue) expr).getValue();
+        }
+        throw new IllegalArgumentException("Unsupported expression in comparison");
+    }
+
+    /**
+     * Translates a group-by clause into a list of column maps.
+     *
+     * @return the columns, or {@code null} when {@code groupBy} is {@code null}
+     * @throws IllegalArgumentException if grouping sets are used or an expression is not a column
+     */
+    private static List parse(GroupByElement groupBy) {
+        if (groupBy == null) {
+            return null;
+        }
+        if (groupBy.getGroupingSets() != null && !groupBy.getGroupingSets().isEmpty()) {
+            throw new IllegalArgumentException("Unsupported group by clause(s)");
+        }
+        return groupBy.getGroupByExpressionList().stream().map(SimpleParser::parseGroupByExpr).toList();
+    }
+
+    /** Translates a single group-by expression, which must be a {@link Column}. */
+    private static Map parseGroupByExpr(Object o) {
+        if (o instanceof Column) {
+            return parse((Column) o);
+        }
+        throw new IllegalArgumentException("Unsupported group by expression(s)");
+    }
+
+    /** Shared, immutable representation of {@code *}. */
+    private static final Map STAR = Map.of("type", "*");
+
+    /**
+     * Translates {@code *}.
+     *
+     * @throws IllegalArgumentException if EXCEPT columns or REPLACE expressions are attached
+     */
+    private static Map parse(AllColumns expr) {
+        if (expr.getExceptColumns() != null || expr.getReplaceExpressions() != null) {
+            throw new IllegalArgumentException("Unsupported expression:" + expr);
+        }
+        return STAR;
+    }
+
+    /** Translates a table into a map with {@code table} and, when the schema name is non-null, {@code schema}. */
+    private static Map parse(Table t) {
+        Map m = new HashMap();
+        String s = t.getSchemaName();
+        if (s != null) {
+            m.put("schema", s);
+        }
+        m.put("table", t.getName());
+        return m;
+    }
+
+    /**
+     * Translates a column into a map with {@code type} and {@code column}, plus {@code table}
+     * when the column has a table and {@code schema} when that table has a schema name.
+     */
+    private static Map parse(Column c) {
+        Map m = new HashMap();
+        m.put("type", "column");
+        Table t = c.getTable();
+        if (t != null) {
+            String s = t.getSchemaName();
+            if (s != null) {
+                m.put("schema", s);
+            }
+            m.put("table", t.getName());
+        }
+        m.put("column", c.getColumnName());
+        return m;
+    }
+
+    /**
+     * Translates a select item: {@code *}, a plain column, or {@code COUNT(*)}.
+     *
+     * @throws IllegalArgumentException for any other expression
+     */
+    private static Map parse(SelectItem item) {
+        // We ignore the alias for now, but could use this in future to create a custom expression with the given name.
+        // item.getAlias();
+
+        Expression exp = item.getExpression();
+        if (exp instanceof AllColumns) {
+            return parse((AllColumns) exp);
+        } else if (exp instanceof Column) {
+            return parse((Column) exp);
+        } else if (exp instanceof Function) {
+            Function f = (Function) exp;
+            if (f.getName().equalsIgnoreCase("COUNT")) {
+                Map m = new HashMap();
+                // Guard against a missing parameter list (presumably what COUNT() parses to - TODO confirm).
+                List params = f.getParameters();
+                if (params == null || params.size() != 1) {
+                    throw new IllegalArgumentException("Malformed COUNT expression");
+                }
+                // get(0) avoids relying on List.getFirst(), which the JDK only added in Java 21.
+                Object p = params.get(0);
+                if (p instanceof AllColumns) {
+                    m.put("type", "count");
+                    m.put("column", "*");
+                    return m;
+                }
+                // If there's a concrete column given, we can add an implicit non-null clause for it.
+                // For now, we simply don't support more complex cases.
+            }
+            // Fall through if it's not supported
+        }
+
+        // The next step would be looking at the full list of expressions that we support.
+        throw new IllegalArgumentException("Unsupported expression(s) in select");
+    }
+
+
+}
diff --git a/src/macaw/scope_experiments.clj b/src/macaw/scope_experiments.clj
new file mode 100644
index 0000000..ed4b9c7
--- /dev/null
+++ b/src/macaw/scope_experiments.clj
@@ -0,0 +1,102 @@
+(ns macaw.scope-experiments
+  (:require
+   [macaw.core :as m]
+   [macaw.walk :as mw])
+  (:import
+   (com.metabase.macaw SimpleParser)
+   (java.util List Map)
+   (net.sf.jsqlparser.schema Column Table)
+   (net.sf.jsqlparser.statement.select SelectItem)))
+
+(defn- java->clj
+  "Recursively converts Java ArrayList and HashMap to Clojure vector and map."
+  [java-obj]
+  (condp instance? java-obj
+    List (mapv java->clj java-obj)
+    Map  (into {} (for [[k v] java-obj]
+                    [(keyword k) (java->clj v)]))
+    java-obj))
+
+(defn query-map [sql]
+  (java->clj (SimpleParser/maybeParse (m/parsed-query sql))))
+
+(defn- node->clj [node]
+  (cond
+    (instance? SelectItem node) [:select-item (.getAlias node) (.getExpression node)]
+    (instance? Column node)     [:column
+                                 (some-> (.getTable node) .getName)
+                                 (.getColumnName node)]
+    (instance? Table node)      [:table (.getName node)]
+    :else                       [(type node) node]))
+
+(defn semantic-map
+  "Name is a bit of a shame, for now this is a fairly low level representation of how we walk the query"
+  [sql]
+  (mw/fold-query (m/parsed-query sql)
+                 {:every-node (fn [acc node ctx]
+                                (let [id   (m/scope-id (first ctx))
+                                      node (node->clj node)]
+                                  (-> acc
+                                      (update-in [:scopes id]
+                                                 (fn [scope]
+                                                   (-> scope
+                                                       (update :path #(or % (mapv m/scope-label (reverse ctx))))
+                                                       (update :children (fnil conj []) node))))
+                                      ((fn [acc']
+                                         (if-let [parent-id (some-> (second ctx) m/scope-id)]
+                                           (-> acc'
+                                               (update :parents assoc id parent-id)
+                                               (update-in [:children parent-id] (fnil conj #{}) id))
+                                           acc')))
+                                      (update :sequence (fnil conj []) [id node #_(mapv m/scope-label (reverse ctx))]))))}
+                 {:scopes   {}    ;; id -> {:path [labels], :children [nodes]}
+                  :parents  {}    ;; what scope is this inside?
+                  :children {}    ;; what scopes are inside?
+                  :sequence []})) ;; [scope-id, node]
+
+(defn- ->descendants
+  "Given a direct mapping, get back the transitive mapping"
+  [parent->children]
+  (reduce
+   (fn [acc parent-id]
+     (let [children (parent->children parent-id)]
+       (assoc acc parent-id (into (set children) (mapcat acc) children))))
+   {}
+   ;; guarantee we process each node before its parent
+   (reverse (sort (keys parent->children)))))
+
+(defn fields->tables-in-scope
+  "Build a sequence of `[[scope-id column-name] table-names]` pairs, giving for each field all the tables that are in scope when it's referenced"
+  [sql]
+  (let [sm                   (semantic-map sql)
+        tables               (filter (comp #{:table} first second) (:sequence sm))
+        scope->tables        (reduce
+                              (fn [m [scope-id [_ table-name]]]
+                                (update m scope-id (fnil conj #{}) table-name))
+                              {}
+                              tables)
+        scope->descendants   (->descendants (:children sm))
+        scope->nested-tables (reduce
+                              (fn [m parent-id]
+                                (assoc m parent-id
+                                       (into (set (scope->tables parent-id)) (mapcat scope->tables (scope->descendants parent-id)))))
+                              {}
+                              (keys (:scopes sm)))
+        columns              (filter (comp #{:column} first second) (:sequence sm))]
+
+    (vec (distinct (for [[scope-id [_ table-name column-name]] columns]
+                     [[scope-id column-name]
+                      (if table-name
+                        #{table-name}
+                        (scope->nested-tables scope-id))])))))
+
+(defn- ->vec [x]
+  (mapv x [:schema :table :column]))
+
+(defn fields-to-search
+  "Get a set of qualified columns. Where the qualification was uncertain, we enumerate all possibilities"
+  [f->ts]
+  (into (sorted-set-by (fn [x y] (compare (->vec x) (->vec y))))
+        (mapcat (fn [[[_ column-name] table-names]]
+                  (map #(hash-map :table % :column column-name) table-names)))
+        f->ts))
diff --git a/src/macaw/walk.clj b/src/macaw/walk.clj
index 8bb1e95..2d4b96a 100644
--- a/src/macaw/walk.clj
+++ b/src/macaw/walk.clj
@@ -10,6 +10,7 @@
   {:alias            AstWalker$CallbackKey/ALIAS
    :column           AstWalker$CallbackKey/COLUMN
    :column-qualifier AstWalker$CallbackKey/COLUMN_QUALIFIER
+   :every-node       AstWalker$CallbackKey/EVERY_NODE
    :mutation         AstWalker$CallbackKey/MUTATION_COMMAND
    :pseudo-table     AstWalker$CallbackKey/PSEUDO_TABLES
    :table            AstWalker$CallbackKey/TABLE
diff --git a/test/macaw/core_test.clj b/test/macaw/core_test.clj
index f28257f..fcf2e1e 100644
--- a/test/macaw/core_test.clj
+++ b/test/macaw/core_test.clj
@@ -10,7 +10,7 @@
    [mb.hawk.assert-exprs])
   (:import
    (clojure.lang ExceptionInfo)
-   (net.sf.jsqlparser.schema Table)))
+   (net.sf.jsqlparser.schema Column Table)))
 
 (set! *warn-on-reflection* true)
 
@@ -619,7 +619,7 @@ from foo")
              raw-components)))))
 
 (comment
-  (require 'user) ;; kondo, really
+  (require 'user) ;; kondo, really
   (require '[clj-async-profiler.core :as prof])
   (prof/serve-ui 8080)
 
@@ -642,6 +642,7 @@ from foo")
   (anonymize-query "SELECT x FROM a")
   (anonymize-fixture :snowflakelet)
 
+  (require 'virgil)
   (require 'clojure.tools.namespace.repl)
   (virgil/watch-and-recompile ["java"]
                               :post-hook clojure.tools.namespace.repl/refresh-all))
diff --git a/test/macaw/scope_experiments_test.clj b/test/macaw/scope_experiments_test.clj
new file mode 100644
index 0000000..23757f7
--- /dev/null
+++ b/test/macaw/scope_experiments_test.clj
@@ -0,0 +1,120 @@
+(ns macaw.scope-experiments-test
+  (:require
+   [clojure.test :refer :all]
+   [macaw.scope-experiments :as mse]))
+
+(set! *warn-on-reflection* true)
+
+(deftest ^:parallel query-map-test
+  (is (= (mse/query-map "SELECT x FROM t")
+         {:select [{:column "x", :type "column"}]
+          :from   [{:table "t"}]}))
+
+  (is (= (mse/query-map "SELECT x FROM t WHERE y = 1")
+         {:select [{:column "x", :type "column"}]
+          :from   [{:table "t"}]
+          :where  [:=
+                   {:column "y", :type "column"}
+                   1]}))
+
+  (is (= (mse/query-map "SELECT x FROM t WHERE y > 1")
+         {:select [{:column "x", :type "column"}]
+          :from   [{:table "t"}]
+          :where  [:> {:column "y", :type "column"} 1]}))
+
+  (is (= (mse/query-map "SELECT x FROM t WHERE y >= 1")
+         {:select [{:column "x", :type "column"}]
+          :from   [{:table "t"}]
+          :where  [:>= {:column "y", :type "column"} 1]}))
+
+  (is (= (mse/query-map "SELECT x, z FROM t WHERE y = 1 GROUP BY z ORDER BY x DESC LIMIT 1")
+         {:select   [{:column "x", :type "column"} {:column "z", :type "column"}],
+          :from     [{:table "t"}],
+          :where    [:= {:column "y", :type "column"} 1]
+          :group-by [{:column "z", :type "column"}],
+          :order-by [{:column "x", :type "column"}],
+          :limit    1,}))
+
+  (is (= (mse/query-map "SELECT x FROM t1, t2")
+         {:select [{:column "x", :type "column"}], :from [{:table "t1"} {:table "t2"}]})))
+
+(deftest ^:parallel semantic-map-test
+  (is (= (mse/semantic-map "select x from t, u, v left join w on w.id = v.id where t.id = u.id and u.id = v.id limit 3")
+         {:scopes   {1 {:path ["SELECT"], :children [[:column nil "x"]]},
+                     2 {:path ["SELECT" "FROM"], :children [[:table "t"]]},
+                     4 {:path ["SELECT" "JOIN" "FROM"], :children [[:table "u"]]},
+                     5 {:path ["SELECT" "JOIN" "FROM"], :children [[:table "v"]]},
+                     6 {:path ["SELECT" "JOIN" "FROM"], :children [[:table "w"]]},
+                     3 {:path ["SELECT" "JOIN"], :children [[:column "w" "id"] [:table "w"] [:column "v" "id"] [:table "v"]]},
+                     7 {:path     ["SELECT" "WHERE"],
+                        :children [[:column "t" "id"]
+                                   [:table "t"]
+                                   [:column "u" "id"]
+                                   [:table "u"]
+                                   [:column "u" "id"]
+                                   [:table "u"]
+                                   [:column "v" "id"]
+                                   [:table "v"]]}},
+          :parents  {2 1, 4 3, 5 3, 6 3, 3 1, 7 1},
+          :children {1 #{7 3 2}, 3 #{4 6 5}},
+          :sequence [[1 [:column nil "x"]]
+                     [2 [:table "t"]]
+                     [4 [:table "u"]]
+                     [5 [:table "v"]]
+                     [6 [:table "w"]]
+                     [3 [:column "w" "id"]]
+                     [3 [:table "w"]]
+                     [3 [:column "v" "id"]]
+                     [3 [:table "v"]]
+                     [7 [:column "t" "id"]]
+                     [7 [:table "t"]]
+                     [7 [:column "u" "id"]]
+                     [7 [:table "u"]]
+                     [7 [:column "u" "id"]]
+                     [7 [:table "u"]]
+                     [7 [:column "v" "id"]]
+                     [7 [:table "v"]]]}))
+
+  (is (= (mse/semantic-map "select t.a,b,c,d from t")
+         {:scopes   {1 {:path     ["SELECT"],
+                        :children [[:column "t" "a"] [:table "t"] [:column nil "b"] [:column nil "c"] [:column nil "d"]]},
+                     2 {:path ["SELECT" "FROM"], :children [[:table "t"]]}},
+          :parents  {2 1},
+          :children {1 #{2}},
+          :sequence [[1 [:column "t" "a"]]
+                     [1 [:table "t"]]
+                     [1 [:column nil "b"]]
+                     [1 [:column nil "c"]]
+                     [1 [:column nil "d"]]
+                     [2 [:table "t"]]]})))
+
+(deftest ^:parallel fields-to-search-test
+  ;; like source-columns, but understands scope
+  (is (= (mse/fields-to-search
+          (mse/fields->tables-in-scope "select x from t, u, v left join w on w.a = v.a where t.b = u.b and u.c = v.c limit 3"))
+         #{{:table "t" :column "b"}
+           {:table "t" :column "x"}
+           {:table "u" :column "b"}
+           {:table "u" :column "c"}
+           {:table "u" :column "x"}
+           {:table "v" :column "a"}
+           {:table "v" :column "c"}
+           {:table "v" :column "x"}
+           {:table "w" :column "a"}
+           {:table "w" :column "x"}}))
+
+  (is (= (mse/fields-to-search
+          (mse/fields->tables-in-scope
+           "with b as (select x, * from a),
+ c as (select y, * from b)
+ select z from c;"))
+         ;; getting there - needs to unwrap cte aliases to the tables that they come from
+         #{{:table "a" :column "x"}
+           {:table "b" :column "y"}
+           {:table "c" :column "z"}}))
+
+  (is (= (mse/fields-to-search
+          (mse/fields->tables-in-scope
+           "select x, y, (select z from u) from t"))
+         ;; totally loses x and y :-(
+         #{{:table "u" :column "z"}})))