diff --git a/sql/mobile_data_demo.sql b/sql/mobile_data_demo.sql index 85c26c949..da6fa70be 100644 --- a/sql/mobile_data_demo.sql +++ b/sql/mobile_data_demo.sql @@ -108,7 +108,6 @@ FROM DIFF ON state, hw_make, hw_model, app_version ORDER BY global_ratio; - -- Should be same as the original double-table query above SELECT app_version, hw_make, hw_model, global_ratio FROM DIFF @@ -116,6 +115,15 @@ FROM DIFF ON state, hw_make, hw_model, app_version ORDER BY global_ratio; +-- Demonstrate subquerying within SPLIT ON +SELECT app_version, hw_make, hw_model, global_ratio +FROM DIFF + (SPLIT ON PREDICATE(battery_drain, ">", 0.9) FROM ( + SELECT * FROM mobile_data where battery_drain > 0.25 + )) + ON state, hw_make, hw_model, app_version + ORDER BY global_ratio; + -- Should yield no results SELECT app_version, hw_make, hw_model, global_ratio FROM DIFF diff --git a/sql/src/main/antlr4/edu/stanford/futuredata/macrobase/SqlBase.g4 b/sql/src/main/antlr4/edu/stanford/futuredata/macrobase/SqlBase.g4 index b01f0fb75..61edd496c 100644 --- a/sql/src/main/antlr4/edu/stanford/futuredata/macrobase/SqlBase.g4 +++ b/sql/src/main/antlr4/edu/stanford/futuredata/macrobase/SqlBase.g4 @@ -117,6 +117,8 @@ diffQuerySpecification splitQuery : SPLIT ON identifier '(' (primaryExpression (',' primaryExpression)*)? ')' FROM relation + | SPLIT ON identifier '(' (primaryExpression (',' primaryExpression)*)? ')' + FROM queryTerm ; columnDefinition diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/ExpressionFormatter.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/ExpressionFormatter.java index 9b6e717d1..5afa26b52 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/ExpressionFormatter.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/ExpressionFormatter.java @@ -61,7 +61,7 @@ import edu.stanford.futuredata.macrobase.sql.tree.LambdaExpression; import edu.stanford.futuredata.macrobase.sql.tree.LikePredicate; import edu.stanford.futuredata.macrobase.sql.tree.LogicalBinaryExpression; -import edu.stanford.futuredata.macrobase.sql.tree.LongLiteral; +import edu.stanford.futuredata.macrobase.sql.tree.IntLiteral; import edu.stanford.futuredata.macrobase.sql.tree.Node; import edu.stanford.futuredata.macrobase.sql.tree.NotExpression; import edu.stanford.futuredata.macrobase.sql.tree.NullIfExpression; @@ -193,7 +193,7 @@ protected String visitSubscriptExpression(SubscriptExpression node, Void context } @Override - protected String visitLongLiteral(LongLiteral node, Void context) { + protected String visitLongLiteral(IntLiteral node, Void context) { return Long.toString(node.getValue()); } diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java index b95b0c26e..c0a8ae3e2 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java @@ -26,6 +26,7 @@ import edu.stanford.futuredata.macrobase.sql.tree.OrderBy; import edu.stanford.futuredata.macrobase.sql.tree.QueryBody; import edu.stanford.futuredata.macrobase.sql.tree.QuerySpecification; +import edu.stanford.futuredata.macrobase.sql.tree.Relation; import edu.stanford.futuredata.macrobase.sql.tree.Select; import edu.stanford.futuredata.macrobase.sql.tree.SelectItem; import edu.stanford.futuredata.macrobase.sql.tree.SortItem; @@ -94,7 +95,7 @@ private DataFrame executeDiffQuerySpec(final DiffQuerySpecification diffQuery) .collect(toImmutableList()); final String outlierColName = "outlier_col"; - DataFrame dfToQuery = null; + DataFrame dfToExplain = null; if (diffQuery.hasTwoArgs()) { final TableSubquery first = diffQuery.getFirst().get(); @@ -111,19 +112,24 @@ private DataFrame executeDiffQuerySpec(final DiffQuerySpecification diffQuery) "ON " + Joiner.on(", ").join(explainCols) + " not present in either" + " outlier or inlier subquery"); } - dfToQuery = combineOutliersAndInliers(outlierColName, outliersDf, inliersDf); + dfToExplain = combineOutliersAndInliers(outlierColName, outliersDf, inliersDf); } else { // splitQuery final SplitQuery splitQuery = diffQuery.getSplitQuery().get(); - Table table = (Table) splitQuery.getRelation(); - final String tableName = table.getName().toString(); - DataFrame df = getTable(tableName); + DataFrame inputDf; + final Relation inputRelation = splitQuery.getInputRelation(); + if (inputRelation instanceof TableSubquery) { + inputDf = executeQuery(((TableSubquery) inputRelation).getQuery().getQueryBody()); + } else { + // instance of Table + inputDf = getTable(((Table) inputRelation).getName().toString()); + } - if (!df.getSchema().hasColumns(explainCols)) { + if (!inputDf.getSchema().hasColumns(explainCols)) { throw new MacrobaseSQLException( "ON " + Joiner.on(", ").join(explainCols) + " not present in table " - + tableName); + + inputRelation); } final String classifierType = splitQuery.getClassifierName().getValue(); @@ -132,8 +138,8 @@ private DataFrame executeDiffQuerySpec(final DiffQuerySpecification diffQuery) .getClassifier(classifierType, splitQuery.getClassifierArgs().stream().map( Expression::toString).collect(toList())); classifier.setOutputColumnName(outlierColName); - classifier.process(df); - dfToQuery = classifier.getResults(); + classifier.process(inputDf); + dfToExplain = classifier.getResults(); } catch (MacrobaseException e) { // this comes from instantiating the classifier; re-throw @@ -144,17 +150,16 @@ private DataFrame executeDiffQuerySpec(final DiffQuerySpecification diffQuery) } } - // TODO: too many get's; too many fields are Optional that shouldn't be final double minRatioMetric = diffQuery.getMinRatioExpression().getMinRatio(); final double minSupport = diffQuery.getMinSupportExpression().getMinSupport(); final ExplanationMetric ratioMetric = ExplanationMetric .getMetricFn(diffQuery.getRatioMetricExpr().getFuncName().toString()); - final long order = diffQuery.getMaxCombo().getValue(); + final int order = diffQuery.getMaxCombo().getValue(); // execute diff // TODO: add support for "ON *" - DataFrame df = diff(dfToQuery, outlierColName, explainCols, minRatioMetric, minSupport, - ratioMetric, (int) order); + DataFrame df = diff(dfToExplain, outlierColName, explainCols, minRatioMetric, minSupport, + ratioMetric, order); return evaluateSQLClauses(diffQuery, df); } diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/parser/AstBuilder.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/parser/AstBuilder.java index bc5158e6d..db1f67357 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/parser/AstBuilder.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/parser/AstBuilder.java @@ -65,6 +65,7 @@ import edu.stanford.futuredata.macrobase.sql.tree.ImportCsv; import edu.stanford.futuredata.macrobase.sql.tree.InListExpression; import edu.stanford.futuredata.macrobase.sql.tree.InPredicate; +import edu.stanford.futuredata.macrobase.sql.tree.IntLiteral; import edu.stanford.futuredata.macrobase.sql.tree.Intersect; import edu.stanford.futuredata.macrobase.sql.tree.IntervalLiteral; import edu.stanford.futuredata.macrobase.sql.tree.IsNotNullPredicate; @@ -77,7 +78,6 @@ import edu.stanford.futuredata.macrobase.sql.tree.LambdaExpression; import edu.stanford.futuredata.macrobase.sql.tree.LikePredicate; import edu.stanford.futuredata.macrobase.sql.tree.LogicalBinaryExpression; -import edu.stanford.futuredata.macrobase.sql.tree.LongLiteral; import edu.stanford.futuredata.macrobase.sql.tree.MinRatioExpression; import edu.stanford.futuredata.macrobase.sql.tree.MinSupportExpression; import edu.stanford.futuredata.macrobase.sql.tree.NaturalJoin; @@ -261,8 +261,11 @@ public Node visitRatioMetricExpression(SqlBaseParser.RatioMetricExpressionContex public Node visitSplitQuery(SqlBaseParser.SplitQueryContext context) { Identifier classifierName = (Identifier) visit(context.identifier()); List classifierArgs = visit(context.primaryExpression(), Expression.class); - Relation relation = (Relation) visit(context.relation()); - return new SplitQuery(classifierName, classifierArgs, relation); + Optional relation = visitIfPresent(context.relation(), Relation.class); + Optional subquery = visitIfPresent(context.queryTerm(), TableSubquery.class); + check(relation.isPresent() || subquery.isPresent(), + "Either a relation or a subquery must be present in a SplitQuery", context); + return new SplitQuery(classifierName, classifierArgs, relation, subquery); } @Override @@ -276,7 +279,7 @@ public Node visitDiffQuerySpecification(SqlBaseParser.DiffQuerySpecificationCont check(subqueries.size() == 0 && splitQuery.isPresent() || subqueries.size() == 2 && !splitQuery .isPresent(), - "At least one and at most two relations required for diff query", context); + "At least one and at most two subqueries required for a DiffQuery", context); if (subqueries.size() == 2) { first = Optional.of(subqueries.get(0)); @@ -295,7 +298,7 @@ public Node visitDiffQuerySpecification(SqlBaseParser.DiffQuerySpecificationCont Identifier.class); check(attributeCols.size() > 0, "At least one attribute must be specified", context); - Optional maxCombo = getTextIfPresent(context.maxCombo).map(LongLiteral::new); + Optional maxCombo = getTextIfPresent(context.maxCombo).map(IntLiteral::new); Optional orderBy = Optional.empty(); if (context.ORDER() != null) { @@ -1088,7 +1091,7 @@ public Node visitTypeConstructor(SqlBaseParser.TypeConstructorContext context) { @Override public Node visitIntegerLiteral(SqlBaseParser.IntegerLiteralContext context) { - return new LongLiteral(getLocation(context), context.getText()); + return new IntLiteral(getLocation(context), context.getText()); } @Override diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/AstVisitor.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/AstVisitor.java index f0651cdd9..f8fbdf61b 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/AstVisitor.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/AstVisitor.java @@ -233,7 +233,7 @@ protected R visitSubscriptExpression(SubscriptExpression node, C context) { return visitExpression(node, context); } - protected R visitLongLiteral(LongLiteral node, C context) { + protected R visitLongLiteral(IntLiteral node, C context) { return visitLiteral(node, context); } diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/DiffQuerySpecification.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/DiffQuerySpecification.java index 72f029603..2f9a674ee 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/DiffQuerySpecification.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/DiffQuerySpecification.java @@ -21,14 +21,14 @@ public class DiffQuerySpecification extends QueryBody { private final MinRatioExpression minRatioExpr; private final MinSupportExpression minSupportExpr; private final RatioMetricExpression ratioMetricExpr; - private final LongLiteral maxCombo; + private final IntLiteral maxCombo; // Optional private final Optional where; private final Optional orderBy; private final Optional limit; private final Optional exportExpr; - private static final LongLiteral DEFAULT_MAX_COMBO = new LongLiteral("3"); + private static final IntLiteral DEFAULT_MAX_COMBO = new IntLiteral("3"); private static final MinRatioExpression DEFAULT_MIN_RATIO_EXPRESSION = new MinRatioExpression( new DecimalLiteral("1.5")); private static final MinSupportExpression DEFAULT_MIN_SUPPORT_EXPRESSION = new MinSupportExpression( @@ -46,7 +46,7 @@ public DiffQuerySpecification( Optional minRatioExpr, Optional minSupportExpr, Optional ratioMetricExpr, - Optional maxCombo, + Optional maxCombo, Optional where, Optional orderBy, Optional limit, @@ -65,7 +65,7 @@ public DiffQuerySpecification( Optional minRatioExpr, Optional minSupportExpr, Optional ratioMetricExpr, - Optional maxCombo, + Optional maxCombo, Optional where, Optional orderBy, Optional limit, @@ -84,7 +84,7 @@ private DiffQuerySpecification( Optional minRatioExpr, Optional minSupportExpr, Optional ratioMetricExpr, - Optional maxCombo, + Optional maxCombo, Optional where, Optional orderBy, Optional limit, @@ -155,7 +155,7 @@ public RatioMetricExpression getRatioMetricExpr() { return ratioMetricExpr; } - public LongLiteral getMaxCombo() { + public IntLiteral getMaxCombo() { return maxCombo; } @@ -190,7 +190,7 @@ public List getChildren() { nodes.add(minRatioExpr); nodes.add(minSupportExpr); nodes.add(ratioMetricExpr); - nodes.add(new LongLiteral("" + maxCombo)); + nodes.add(new IntLiteral("" + maxCombo)); where.ifPresent(nodes::add); orderBy.ifPresent(nodes::add); limit.ifPresent((str) -> nodes.add(new StringLiteral(str))); diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/LongLiteral.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/IntLiteral.java similarity index 81% rename from sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/LongLiteral.java rename to sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/IntLiteral.java index 0708248ca..064f3d079 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/LongLiteral.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/IntLiteral.java @@ -18,29 +18,29 @@ import edu.stanford.futuredata.macrobase.sql.parser.ParsingException; import java.util.Optional; -public class LongLiteral extends Literal { +public class IntLiteral extends Literal { - private final long value; + private final int value; - public LongLiteral(String value) { + public IntLiteral(String value) { this(Optional.empty(), value); } - public LongLiteral(NodeLocation location, String value) { + public IntLiteral(NodeLocation location, String value) { this(Optional.of(location), value); } - private LongLiteral(Optional location, String value) { + private IntLiteral(Optional location, String value) { super(location); requireNonNull(value, "value is null"); try { - this.value = Long.parseLong(value); + this.value = Integer.parseInt(value); } catch (NumberFormatException e) { throw new ParsingException("Invalid numeric literal: " + value); } } - public Long getValue() { + public int getValue() { return value; } @@ -58,7 +58,7 @@ public boolean equals(Object o) { return false; } - LongLiteral that = (LongLiteral) o; + IntLiteral that = (IntLiteral) o; return value == that.value; } diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SplitQuery.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SplitQuery.java index 322a61d16..dc33c5f09 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SplitQuery.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SplitQuery.java @@ -12,28 +12,33 @@ public class SplitQuery extends Node { private final Identifier classifierName; private final List classifierArgs; - private final Relation relation; + private final Optional relation; + private final Optional subquery; public SplitQuery(Identifier classifierName, List classifierArgs, - Relation relation) { - this(Optional.empty(), classifierName, classifierArgs, relation); + Optional relation, Optional subquery) { + this(Optional.empty(), classifierName, classifierArgs, relation, subquery); } public SplitQuery(NodeLocation location, Identifier classifierName, - List classifierArgs, Relation relation) { - this(Optional.of(location), classifierName, classifierArgs, relation); + List classifierArgs, Optional relation, + Optional subquery) { + this(Optional.of(location), classifierName, classifierArgs, relation, subquery); } private SplitQuery(Optional location, Identifier classifierName, - List classifierArgs, Relation relation) { + List classifierArgs, Optional relation, + Optional subquery) { super(location); requireNonNull(classifierName, "classifierName is null"); requireNonNull(classifierArgs, "classifierArgs is null"); requireNonNull(relation, "relation is null"); + requireNonNull(subquery, "subquery is null"); this.classifierName = classifierName; this.classifierArgs = classifierArgs; this.relation = relation; + this.subquery = subquery; } public Identifier getClassifierName() { @@ -44,8 +49,8 @@ public List getClassifierArgs() { return classifierArgs; } - public Relation getRelation() { - return relation; + public Relation getInputRelation() { + return relation.orElseGet(subquery::get); } @Override @@ -53,7 +58,8 @@ public List getChildren() { ImmutableList.Builder nodes = ImmutableList.builder(); nodes.add(classifierName); nodes.addAll(classifierArgs); - nodes.add(relation); + relation.ifPresent(nodes::add); + subquery.ifPresent(nodes::add); return nodes.build(); }