diff --git a/build.sh b/build.sh index 806336916..fd70b8853 100755 --- a/build.sh +++ b/build.sh @@ -2,4 +2,27 @@ set -e -cd lib && mvn clean && mvn install && cd ../sql && mvn clean && mvn package -DskipTests +build_module () { + pushd $1 + mvn clean && mvn package -DskipTests + popd +} + +pushd lib/ +mvn clean && mvn install -DskipTests +popd + +if [[ $# -eq 0 ]]; then + build_module core sql +else + while [[ $# -gt 0 ]] + do + if [ -e "$1"/pom.xml ]; then + build_module $1 + else + echo "$1 does not contain a module" + fi + shift # past argument + done +fi + diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/Classifier.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/Classifier.java index 526cdd19a..21a516658 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/Classifier.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/Classifier.java @@ -1,9 +1,11 @@ package edu.stanford.futuredata.macrobase.analysis.classify; import com.google.common.base.Joiner; +import edu.stanford.futuredata.macrobase.datamodel.DataFrame; import edu.stanford.futuredata.macrobase.operator.Transformer; import edu.stanford.futuredata.macrobase.util.MacrobaseException; import java.lang.reflect.InvocationTargetException; +import java.util.BitSet; import java.util.List; public abstract class Classifier implements Transformer { @@ -37,6 +39,8 @@ public Classifier setOutputColumnName(String outputColumnName) { return this; } + public abstract BitSet getMask(final DataFrame input); + public static Classifier getClassifier(String classifierType, List args) throws MacrobaseException { Class clazz; diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PercentileClassifier.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PercentileClassifier.java index 2c4c28136..0e7f4dd9d 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PercentileClassifier.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PercentileClassifier.java @@ -2,6 +2,7 @@ import edu.stanford.futuredata.macrobase.datamodel.DataFrame; import edu.stanford.futuredata.macrobase.util.MacrobaseException; +import java.util.BitSet; import java.util.List; import org.apache.commons.math3.stat.descriptive.rank.Percentile; @@ -34,7 +35,8 @@ public PercentileClassifier(String columnName) { */ public PercentileClassifier(List attrs) throws MacrobaseException { this(attrs.get(0)); - percentile = Double.parseDouble(attrs.get(1)); + percentile = 100 * (1 - Double + .parseDouble(attrs.get(1))); // TODO: this is stupid -- we need to standardize this includeHigh = (attrs.size() <= 2) || Boolean .parseBoolean(attrs.get(2)); // 3rd arg if present else true @@ -63,6 +65,25 @@ public void process(DataFrame input) { output.addColumn(outputColumnName, resultColumn); } + @Override + public BitSet getMask(DataFrame input) { + final double[] inputCol = input.getDoubleColumnByName(columnName); + final int numRows = inputCol.length; + lowCutoff = new Percentile().evaluate(inputCol, percentile); + highCutoff = new Percentile().evaluate(inputCol, 100.0 - percentile); + final BitSet mask = new BitSet(numRows); + + for (int i = 0; i < numRows; i++) { + double curVal = inputCol[i]; + if ((curVal > highCutoff && includeHigh) + || (curVal < lowCutoff && includeLow) + ) { + mask.set(i); + } + } + return mask; + } + @Override public DataFrame getResults() { return output; diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PredicateClassifier.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PredicateClassifier.java index d4b2835ba..af47412d2 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PredicateClassifier.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/classify/PredicateClassifier.java @@ -3,24 +3,25 @@ import edu.stanford.futuredata.macrobase.analysis.classify.stats.MBPredicate; import edu.stanford.futuredata.macrobase.datamodel.DataFrame; import edu.stanford.futuredata.macrobase.util.MacrobaseException; - +import java.util.BitSet; import java.util.List; import java.util.function.DoublePredicate; import java.util.function.Predicate; /** - * PredicateClassifier classifies an outlier based on a predicate(e.g., equality, less than, greater than) - * and a hard-coded sentinel value. Unlike {@link PercentileClassifier}, outlier values are not determined based on a - * proportion of the values in the metric column. Instead, the outlier values are defined explicitly by the user in the - * conf.yaml file; for example: + * PredicateClassifier classifies an outlier based on a predicate(e.g., equality, less than, greater + * than) and a hard-coded sentinel value. Unlike {@link PercentileClassifier}, outlier values are + * not determined based on a proportion of the values in the metric column. Instead, the outlier + * values are defined explicitly by the user in the conf.yaml file; for example: * * classifier: "raw_threshold" * metric: "usage" * predicate: "==" * value: 1.0 * - * This would instantiate a PredicateClassifier that classifies every value in the "usage" column equal to 1.0 - * as an outlier. Currently, we support six different predicates: "==", "!=", "<", ">", "<=", and ">=". + * This would instantiate a PredicateClassifier that classifies every value in the "usage" column + * equal to 1.0 as an outlier. Currently, we support six different predicates: + * "==", "!=", "<", ">", "<=", and ">=". */ public class PredicateClassifier extends Classifier { @@ -34,10 +35,10 @@ public class PredicateClassifier extends Classifier { * @param columnName Column on which to classifier outliers * @param predicateStr Predicate used for classification: "==", "!=", "<", ">", "<=", or ">=" * @param sentinel Sentinel value used when evaluating the predicate to determine outlier - * @throws MacrobaseException */ - public PredicateClassifier(final String columnName, final String predicateStr, final double sentinel) - throws MacrobaseException { + public PredicateClassifier(final String columnName, final String predicateStr, + final double sentinel) + throws MacrobaseException { super(columnName); this.predicate = MBPredicate.getDoublePredicate(predicateStr, sentinel); this.isStrPredicate = false; @@ -48,10 +49,10 @@ public PredicateClassifier(final String columnName, final String predicateStr, f * @param columnName Column on which to classifier outliers * @param predicateStr Predicate used for classification: "==", "!=", "<", ">", "<=", or ">=" * @param sentinel Sentinel value used when evaluating the predicate to determine outlier - * @throws MacrobaseException */ - public PredicateClassifier(final String columnName, final String predicateStr, final String sentinel) - throws MacrobaseException { + public PredicateClassifier(final String columnName, final String predicateStr, + final String sentinel) + throws MacrobaseException { super(columnName); this.strPredicate = MBPredicate.getStrPredicate(predicateStr, sentinel); this.isStrPredicate = true; @@ -61,8 +62,8 @@ public PredicateClassifier(final String columnName, final String predicateStr, f * Alternate constructor that takes in List of Strings; used to instantiate Classifier (via * reflection) specified in MacroBase SQL query * - * @param attrs by convention, should be a List that has 3 values: [outlier_col_name, - * predicate type ("==","!=", etc.), sentinel (either String or double)] + * @param attrs by convention, should be a List that has 3 values: [outlier_col_name, predicate + * type ("==","!=", etc.), sentinel (either String or double)] */ public PredicateClassifier(final List attrs) throws MacrobaseException { super(attrs.get(0)); @@ -80,25 +81,54 @@ public PredicateClassifier(final List attrs) throws MacrobaseException { } } + @Override + public BitSet getMask(final DataFrame input) { + if (isStrPredicate) { + final String[] metrics = input.getStringColumnByName(columnName); + return getMask(metrics, strPredicate); + } else { + final double[] metrics = input.getDoubleColumnByName(columnName); + return getMask(metrics, predicate); + } + } + + private BitSet getMask(final double[] metrics, final DoublePredicate predicate) { + final int numRows = metrics.length; + final BitSet mask = new BitSet(numRows); + for (int i = 0; i < numRows; i++) { + if (predicate.test(metrics[i])) { + mask.set(i); + } + } + return mask; + } + private BitSet getMask(final String[] metrics, final Predicate predicate) { + final int numRows = metrics.length; + final BitSet mask = new BitSet(numRows); + for (int i = 0; i < numRows; i++) { + if (predicate.test(metrics[i])) { + mask.set(i); + } + } + return mask; + } /** - * Scan through the metric column, and evaluate the predicate on every value in the column. The ``input'' DataFrame - * remains unmodified; a copy is created and all modifications are made on the copy. - * @throws Exception + * Scan through the metric column, and evaluate the predicate on every value in the column. The + * ``input'' DataFrame remains unmodified; a copy is created and all modifications are made on + * the copy. */ @Override public void process(DataFrame input) throws Exception { if (isStrPredicate) { processString(input); - } - else { + } else { processDouble(input); } } - - public void processDouble(DataFrame input) throws Exception { + private void processDouble(DataFrame input) throws Exception { double[] metrics = input.getDoubleColumnByName(columnName); int len = metrics.length; output = input.copy(); @@ -114,8 +144,7 @@ public void processDouble(DataFrame input) throws Exception { output.addColumn(outputColumnName, resultColumn); } - - public void processString(DataFrame input) throws Exception { + private void processString(DataFrame input) throws Exception { String[] metrics = input.getStringColumnByName(columnName); int len = metrics.length; output = input.copy(); @@ -136,4 +165,5 @@ public void processString(DataFrame input) throws Exception { public DataFrame getResults() { return output; } + } diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/DataFrame.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/DataFrame.java index 88c5df4e9..f1d614832 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/DataFrame.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/DataFrame.java @@ -30,8 +30,10 @@ * initialized from a schema and a set of rows. */ public class DataFrame { - private Schema schema; + private static final int MAX_COLS_FOR_TABULAR_PRINT = 10; + + private Schema schema; private ArrayList stringCols; private ArrayList doubleCols; // external indices define a global ordering on columns, but internally each @@ -142,38 +144,52 @@ public String toString() { * | val_m1 | val_m2 | ... | val_mn | * ------------------------------------------ * @param out PrintStream to write to STDOUT or file (default: STDOUT) - * @param maxNumToPrint maximum number of rows from the DataFrame to print (default: 15) + * @param maxNumToPrint maximum number of rows from the DataFrame to print (default: -1, i.e., + * all rows) */ public void prettyPrint(final PrintStream out, final int maxNumToPrint) { out.println(numRows + (numRows == 1 ? " row" : " rows")); + out.println(); final int maxColNameLength = schema.getColumnNames().stream() - .reduce("", (x, y) -> x.length() > y.length() ? x : y).length() + 4; // 2 extra spaces on both sides - final String schemaStr = "|" + Joiner.on("|").join(schema.getColumnNames().stream() - .map((x) -> StringUtils.center(String.valueOf(x), maxColNameLength)).collect(toList())) + "|"; - final String dashes = Joiner.on("").join(Collections.nCopies(schemaStr.length(), "-")); - out.println(dashes); - out.println(schemaStr); - out.println(dashes); - - if (numRows > maxNumToPrint) { - final int numToPrint = maxNumToPrint / 2; - for (Row r : getRows(0, numToPrint)) { - r.prettyPrint(out, maxColNameLength); - } - out.println(); - out.println("..."); - out.println(); - for (Row r : getRows(numRows - numToPrint, numRows)) { - r.prettyPrint(out, maxColNameLength); + .reduce("", (x, y) -> x.length() > y.length() ? x : y).length(); + + if (schema.getNumColumns() > MAX_COLS_FOR_TABULAR_PRINT) { + // print each row so that each value is on a separate line + for (Row r : getRows()) { + r.prettyPrintColumnWise(out, maxColNameLength); } } else { - for (Row r : getRows()) { - r.prettyPrint(out, maxColNameLength); + // print DataFrame as a table + final int tableWidth = + maxColNameLength + 4; // 2 extra spaces on both sides of each column name and value + final List colStrs = schema.getColumnNames().stream() + .map((x) -> StringUtils.center(String.valueOf(x), tableWidth)).collect(toList()); + final String schemaStr = "|" + Joiner.on("|").join(colStrs) + "|"; + final String dashes = Joiner.on("").join(Collections.nCopies(schemaStr.length(), "-")); + out.println(dashes); + out.println(schemaStr); + out.println(dashes); + + if (maxNumToPrint > 0 && numRows > maxNumToPrint) { + final int numToPrint = maxNumToPrint / 2; + for (Row r : getRows(0, numToPrint)) { + r.prettyPrint(out, tableWidth); + } + out.println(); + out.println("..."); + out.println(); + for (Row r : getRows(numRows - numToPrint, numRows)) { + r.prettyPrint(out, tableWidth); + } + } else { + for (Row r : getRows()) { + r.prettyPrint(out, tableWidth); + } } + out.println(dashes); + out.println(); } - out.println(dashes); - out.println(); } /** diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/Row.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/Row.java index b7007c0f2..12ba5f41e 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/Row.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/datamodel/Row.java @@ -5,6 +5,7 @@ import com.google.common.base.Joiner; import java.io.PrintStream; import java.text.DecimalFormat; +import java.util.Collections; import java.util.List; import org.apache.commons.lang3.StringUtils; @@ -12,6 +13,7 @@ * Format for import / export small batches */ public class Row { + // Formatter for printing out doubles; print at least 1 and no more than 6 decimal places private static final DecimalFormat DOUBLE_FORMAT = new DecimalFormat("#.0#####"); @@ -22,6 +24,7 @@ public Row(Schema schema, List vals) { this.schema = schema; this.vals = vals; } + public Row(List vals) { this.schema = null; this.vals = vals; @@ -33,7 +36,7 @@ public List getVals() { @SuppressWarnings("unchecked") public T getAs(int i) { - return (T)vals.get(i); + return (T) vals.get(i); } @SuppressWarnings("unchecked") @@ -41,14 +44,18 @@ public T getAs(String colName) { if (schema == null) { throw new RuntimeException("No Schema"); } else { - return (T)vals.get(schema.getColumnIndex(colName)); + return (T) vals.get(schema.getColumnIndex(colName)); } } @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } Row row = (Row) o; return vals != null ? vals.equals(row.vals) : row.vals == null; @@ -67,36 +74,34 @@ public String toString() { /** * pretty print Row object to STDOUT or file (default: STDOUT), using a default width of 15 - * characters per value. Example output: - * | val_1 | val_2 | .... | val_n | + * characters per value. Example output: | val_1 | val_2 | .... | val_n | */ public void prettyPrint() { prettyPrint(System.out, 15); } /** - * pretty print Row object to out using a default width of 15 - * characters per value. Example output: - * | val_1 | val_2 | .... | val_n | + * pretty print Row object to out using a default width of 15 characters per value. + * Example output: | val_1 | val_2 | .... | val_n | */ public void prettyPrint(final PrintStream out) { prettyPrint(out, 15); } /** - * pretty print Row object to STDOUT - * Example output: - * | val_1 | val_2 | .... | val_n | - * @param width number of characters to or each value, with (width - length of value) / 2 of - * whitespace on either side + * pretty print Row object to STDOUT Example output: | val_1 | val_2 | .... | val_n + * | + * + * @param width number of characters to or each value, with (width - length of value) / + * 2 of whitespace on either side */ public void prettyPrint(final int width) { prettyPrint(System.out, width); } /** - * pretty print Row object to the console. Example output: - * | val_1 | val_2 | .... | val_n | + * pretty print Row object to the console. Example output: | val_1 | val_2 | .... | + * val_n | * * @param out PrintStream to print Row to STDOUT or file (default: STDOUT) * @param width the number of characters to use for centering a single value. Increasing @@ -108,10 +113,26 @@ public void prettyPrint(final PrintStream out, final int width) { .collect(toList())) + "|"); } + public void prettyPrintColumnWise(final PrintStream out, final int maxColNameLength) { + int maxLength = 0; + for (int i = 0; i < schema.getNumColumns(); ++i) { + final String colName = schema.getColumnName(i); + final Object val = vals.get(i); + final String strToPrint = + StringUtils.rightPad(colName, maxColNameLength) + " | " + formatVal(val, 40); + if (strToPrint.length() > maxLength) { + maxLength = strToPrint.length(); + } + out.println(strToPrint); + } + final String dashes = Joiner.on("").join(Collections.nCopies(maxLength + 5, "-")); + out.println(dashes); + } + /** * @return If x is a double, return back a formatted String that prints at least 1 and up to 6 * decimal places of the double. If x is null, return "-". Otherwise, return x unchanged (i.e. - * toString will be used), but truncate it to @param length + * toString will be used), but truncate it to @param width, if greater than 0 */ private String formatVal(Object x, final int width) { if (x == null) { @@ -123,7 +144,7 @@ private String formatVal(Object x, final int width) { return DOUBLE_FORMAT.format(x); } else { final String str = String.valueOf(x); - if (str.length() > width) { + if (width > 0 && str.length() > width) { return str.substring(0, width - 3) + "..."; } else { return str; diff --git a/lib/src/test/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanationTest.java b/lib/src/test/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanationTest.java new file mode 100644 index 000000000..fdd676e4b --- /dev/null +++ b/lib/src/test/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanationTest.java @@ -0,0 +1,18 @@ +package edu.stanford.futuredata.macrobase.analysis.summary.aplinear; + +import static org.junit.Assert.*; + +import org.junit.After; +import org.junit.Before; + +public class APLExplanationTest { + + @Before + public void setUp() throws Exception { + } + + @After + public void tearDown() throws Exception { + } + +} \ No newline at end of file diff --git a/sql/pom.xml b/sql/pom.xml index 36e6b7e56..34e45d356 100644 --- a/sql/pom.xml +++ b/sql/pom.xml @@ -52,7 +52,7 @@ jline jline - 2.14.2 + 2.14.5 diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/MacroBaseSQLRepl.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/MacroBaseSQLRepl.java index 0381caae5..2bfc23652 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/MacroBaseSQLRepl.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/MacroBaseSQLRepl.java @@ -16,9 +16,11 @@ import edu.stanford.futuredata.macrobase.sql.tree.Statement; import edu.stanford.futuredata.macrobase.util.MacrobaseException; import java.io.File; +import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; +import java.io.PrintStream; import java.nio.file.Paths; import jline.console.ConsoleReader; import jline.console.completer.CandidateListCompletionHandler; @@ -57,13 +59,13 @@ private MacroBaseSQLRepl() throws IOException { * * @param queries A single String which contains the queries to execute. Each query in the * String should be delimited by ';' and optional whitespace. - * @param print If True, print query to the console (useful when reading queries from file) + * @param fromFile If True, queries have been read from File */ - private void executeQueries(final String queries, final boolean print) { + private void executeQueries(final String queries, final boolean fromFile) { StatementSplitter splitter = new StatementSplitter(queries); for (StatementSplitter.Statement s : splitter.getCompleteStatements()) { final String statementStr = s.statement(); - if (print) { + if (fromFile) { System.out.println(statementStr + ";"); System.out.println(); System.out.flush(); @@ -73,13 +75,26 @@ private void executeQueries(final String queries, final boolean print) { try { Statement stmt = parser.createStatement(statementStr); log.debug(stmt.toString()); + final DataFrame result; if (stmt instanceof ImportCsv) { final ImportCsv importStatement = (ImportCsv) stmt; - queryEngine.importTableFromCsv(importStatement).prettyPrint(); + result = queryEngine.importTableFromCsv(importStatement); } else { - QueryBody q = ((Query) stmt).getQueryBody(); - final DataFrame result = queryEngine.executeQuery(q); - result.prettyPrint(); + final QueryBody q = ((Query) stmt).getQueryBody(); + result = queryEngine.executeQuery(q); + } + try { + final PrintStream ps = new PrintStream(new FileOutputStream("/tmp/mb-sql.output")); + result.prettyPrint(ps, -1); + ProcessBuilder pb = new ProcessBuilder("less", "/tmp/mb-sql.output"); + pb.inheritIO(); + Process p = pb.start(); + p.waitFor(); + } catch (InterruptedException | IOException e) { + e.printStackTrace(); + } + if (stmt instanceof Query) { + final QueryBody q = ((Query) stmt).getQueryBody(); q.getExportExpr().ifPresent((exportExpr) -> { // print result to file; if file already exists, do nothing and print error message final String filename = exportExpr.getFilename(); @@ -183,7 +198,6 @@ public static void main(String... args) throws IOException { if (!printedWelcome) { System.out.println(asciiArt); } - repl.runRepl(); } diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java index acccaae91..24ffc171a 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import edu.stanford.futuredata.macrobase.analysis.MBFunction; +import edu.stanford.futuredata.macrobase.analysis.classify.Classifier; import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.APLOutlierSummarizer; import edu.stanford.futuredata.macrobase.datamodel.DataFrame; import edu.stanford.futuredata.macrobase.datamodel.Schema.ColType; @@ -52,7 +53,9 @@ import java.util.Set; import java.util.function.DoublePredicate; import java.util.function.Predicate; +import java.util.stream.Collectors; import java.util.stream.DoubleStream; +import java.util.stream.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -229,16 +232,16 @@ private List findExplanationColumns(DataFrame dfToExplain) { * SingleColumn} */ private List getUDFsInSelect(final Select select) { - final List functionCalls = new ArrayList<>(); + final List udfs = new ArrayList<>(); for (SelectItem item : select.getSelectItems()) { if (item instanceof SingleColumn) { final SingleColumn col = (SingleColumn) item; if (col.getExpression() instanceof FunctionCall) { - functionCalls.add(col); + udfs.add(col); } } } - return functionCalls; + return udfs; } /** @@ -313,7 +316,8 @@ private DataFrame executeQuerySpec(final QuerySpecification query) * Get table as DataFrame that has previously been loaded into memory * * @param tableName String that uniquely identifies table - * @return DataFrame for table + * @return a shallow copy of the DataFrame for table; the original DataFrame is never returned, + * so that we keep it immutable * @throws MacrobaseSQLException if the table has not been loaded into memory and does not * exist */ @@ -321,16 +325,16 @@ private DataFrame getTable(String tableName) throws MacrobaseSQLException { if (!tablesInMemory.containsKey(tableName)) { throw new MacrobaseSQLException("Table " + tableName + " does not exist"); } - return tablesInMemory.get(tableName); + return tablesInMemory.get(tableName).copy(); } /** - * Evaluate only the UDFs of SQL query and return a Map of column names -> double arrays. If - * there are no UDFs (i.e. @param udfCols is empty), an empty Map is returned. + * Evaluate only the UDFs of SQL query and return a new DataFrame with the UDF-generated columns + * added to the input DataFrame. If there are no UDFs (i.e. @param udfCols is empty), the input + * DataFrame is returned as is. * * @param inputDf The DataFrame to evaluate the UDFs on * @param udfCols The List of UDFs to evaluate - * @return The Map of new columns to be added */ private DataFrame evaluateUDFs(final DataFrame inputDf, final List udfCols) throws MacrobaseException { @@ -405,7 +409,7 @@ private DataFrame evaluateLimitClause(final DataFrame df, final Optional * whereClauseOpt is not Present, we return df */ private DataFrame evaluateWhereClause(final DataFrame df, - final Optional whereClauseOpt) throws MacrobaseSQLException { + final Optional whereClauseOpt) throws MacrobaseException { if (!whereClauseOpt.isPresent()) { return df; } @@ -423,7 +427,7 @@ private DataFrame evaluateWhereClause(final DataFrame df, * @throws MacrobaseSQLException Only comparison expressions (e.g., WHERE x = 42) and logical * AND/OR/NOT combinations of such expressions are supported; exception is thrown otherwise. */ - private BitSet getMask(DataFrame df, Expression whereClause) throws MacrobaseSQLException { + private BitSet getMask(DataFrame df, Expression whereClause) throws MacrobaseException { if (whereClause instanceof NotExpression) { final NotExpression notExpr = (NotExpression) whereClause; final BitSet mask = getMask(df, notExpr.getValue()); @@ -459,11 +463,24 @@ private BitSet getMask(DataFrame df, Expression whereClause) throws MacrobaseSQL return maskForPredicate(df, (Literal) left, (Identifier) right, type); } else if (right instanceof Literal && left instanceof Identifier) { return maskForPredicate(df, (Literal) right, (Identifier) left, type); + } else if (left instanceof FunctionCall && right instanceof Literal) { + return maskForPredicate(df, (FunctionCall) left, (Literal) right); + } else if (right instanceof FunctionCall && left instanceof Literal) { + return maskForPredicate(df, (FunctionCall) right, (Literal) left); } } throw new MacrobaseSQLException("Boolean expression not supported"); } + private BitSet maskForPredicate(DataFrame df, FunctionCall func, Literal val) + throws MacrobaseException { + final String funcName = func.getName().getSuffix(); + final Classifier classifier = Classifier.getClassifier(funcName, + Stream.concat(func.getArguments().stream().map(Expression::toString), + Stream.of(val.toString())).collect(Collectors.toList())); + return classifier.getMask(df); + } + /** * The base case for {@link QueryEngine#getMask(DataFrame, Expression)}; returns a boolean mask diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java index 0036f93c8..38d536a32 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java @@ -59,6 +59,10 @@ public Expression getExpression() { return expression; } + public boolean isUDF() { + return expression instanceof FunctionCall; + } + @Override public boolean equals(Object obj) { if (this == obj) {