From 0bb2c45fa56e4299538c9cd5c700ba6f5120d89a Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Thu, 8 Feb 2018 12:46:39 -0800 Subject: [PATCH] Incorporating Peter's changes to APriori --- .../analysis/summary/BatchSummarizer.java | 38 ++++++-- .../summary/aplinear/APLExplanation.java | 90 ++++++++++++++++--- .../aplinear/APLExplanationResult.java | 64 +++++++++---- .../aplinear/APLOutlierSummarizer.java | 34 ++++--- .../summary/aplinear/APLSummarizer.java | 4 +- .../summary/aplinear/APrioriLinear.java | 3 +- .../summary/ratios/ExplanationMetric.java | 21 ----- .../futuredata/macrobase/sql/QueryEngine.java | 88 ++++++++---------- .../macrobase/sql/tree/SingleColumn.java | 20 ++--- 9 files changed, 229 insertions(+), 133 deletions(-) diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/BatchSummarizer.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/BatchSummarizer.java index 0e13fe349..ebb9ee3d8 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/BatchSummarizer.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/BatchSummarizer.java @@ -2,28 +2,31 @@ import edu.stanford.futuredata.macrobase.datamodel.DataFrame; import edu.stanford.futuredata.macrobase.operator.Operator; - +import edu.stanford.futuredata.macrobase.util.MacrobaseException; import java.util.ArrayList; import java.util.List; /** - * Takes a dataframe with binary classification and searches for explanations - * (subgroup discovery / contrast set mining / feature selection) - * that capture differences between the two groups. + * Takes a dataframe with binary classification and searches for explanations (subgroup discovery / + * contrast set mining / feature selection) that capture differences between the two groups. * - * outlierColumn should either be 0.0 or 1.0 to signify outlying points or - * a count of the number of outliers represented by a row + * outlierColumn should either be 0.0 or 1.0 to signify outlying points or a count of the number of + * outliers represented by a row */ public abstract class BatchSummarizer implements Operator { + // Parameters protected String outlierColumn = "_OUTLIER"; protected double minOutlierSupport = 0.1; protected double minRatioMetric = 3; protected List attributes = new ArrayList<>(); protected int numThreads = Runtime.getRuntime().availableProcessors(); + protected String ratioMetric; + protected int maxOrder; /** * Adjust this to tune the significance (e.g. number of rows affected) of the results returned. + * * @param minSupport lowest outlier support of the results returned. */ public BatchSummarizer setMinSupport(double minSupport) { @@ -38,14 +41,17 @@ public BatchSummarizer setAttributes(List attributes) { /** * Set the column which indicates outlier status. "_OUTLIER" by default. + * * @param outlierColumn new outlier indicator column. */ public BatchSummarizer setOutlierColumn(String outlierColumn) { this.outlierColumn = outlierColumn; return this; } + /** * Adjust this to tune the severity (e.g. strength of correlation) of the results returned. + * * @param minRatioMetric lowest risk ratio to consider for meaningful explanations. */ @@ -56,7 +62,25 @@ public BatchSummarizer setMinRatioMetric(double minRatioMetric) { /** * The number of threads used in parallel summarizers. + * * @param numThreads Number of threads to use. */ - public void setNumThreads(int numThreads) { this.numThreads = numThreads; } + public BatchSummarizer setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + public BatchSummarizer setRatioMetric(final String ratioMetric) { + this.ratioMetric = ratioMetric; + return this; + } + + public BatchSummarizer setMaxOrder(final int maxOrder) throws MacrobaseException { + if (maxOrder < 1 || maxOrder > 3) { + throw new MacrobaseException("Max Order " + maxOrder + + " cannot be less than 1 or greater than 3"); + } + this.maxOrder = maxOrder; + return this; + } } diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanation.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanation.java index 0984affd4..b9c37e873 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanation.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanation.java @@ -2,38 +2,37 @@ import com.fasterxml.jackson.annotation.JsonProperty; import edu.stanford.futuredata.macrobase.analysis.summary.Explanation; -import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.QualityMetric; import edu.stanford.futuredata.macrobase.analysis.summary.util.AttributeEncoder; - +import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.QualityMetric; +import edu.stanford.futuredata.macrobase.datamodel.DataFrame; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; public class APLExplanation implements Explanation { + private AttributeEncoder encoder; private List aggregateNames; private long numTotal; private long numOutliers; private ArrayList metrics; - private List thresholds; private ArrayList results; public APLExplanation( - AttributeEncoder encoder, - long numTotal, - long numOutliers, - List aggregateNames, - List metrics, - List thresholds, - List results + AttributeEncoder encoder, + long numTotal, + long numOutliers, + List aggregateNames, + List metrics, + List results ) { this.encoder = encoder; this.numTotal = numTotal; this.numOutliers = numOutliers; this.aggregateNames = aggregateNames; this.metrics = new ArrayList<>(metrics); - this.thresholds = new ArrayList<>(thresholds); this.results = new ArrayList<>(results); } @@ -63,14 +62,77 @@ public double numOutliers() { @Override public String prettyPrint() { StringBuilder header = new StringBuilder(String.format( - "Outlier Explanation:\n" + "Outlier Explanation:\n" )); - header.append("Outliers: "+numOutliers+", Total: "+numTotal+"\n"); + header.append("Outliers: " + numOutliers + ", Total: " + numTotal + "\n"); for (APLExplanationResult is : results) { header.append( - "---\n"+is.prettyPrint(encoder, aggregateNames) + "---\n" + is.prettyPrint(encoder, aggregateNames) ); } return header.toString(); } + + /** + * Convert List of {@link APLExplanationResult} to normalized DataFrame that includes all + * metrics contained in each results. + * + * @param attrsToInclude the attributes (String columns) to be included in the DataFrame + * @return New DataFrame with attrsToInclude columns and ratio metric, support, and + * outlier count columns + */ + public DataFrame toDataFrame(final List attrsToInclude) { + // String column values that will be added to DataFrame + final Map stringResultsByCol = new HashMap<>(); + for (String colName : attrsToInclude) { + stringResultsByCol.put(colName, new String[results.size()]); + } + + // double column values that will be added to the DataFrame + final Map doubleResultsByCol = new HashMap<>(); + for (String colName : aggregateNames) { + doubleResultsByCol.put(colName, new double[results.size()]); + } + for (QualityMetric metric : metrics) { + // NOTE: we assume that the QualityMetrics here are the same ones + // that each APLExplanationResult has + doubleResultsByCol.put(metric.name(), new double[results.size()]); + } + + // Add result rows to individual columns + int i = 0; + for (APLExplanationResult result : results) { + // attrValsInRow contains the values for the explanation attribute values in this + // given row + final Map attrValsInRow = result.prettyPrintMatch(encoder); + for (String colName : stringResultsByCol.keySet()) { + // Iterate over all attributes that will be in the DataFrame. + // If attribute is present in attrValsInRow, add its corresponding value. + // Otherwise, add null + stringResultsByCol.get(colName)[i] = attrValsInRow.get(colName); + } + + final Map metricValsInRow = result.getMetricsAsMap(); + for (String colName : doubleResultsByCol.keySet()) { + doubleResultsByCol.get(colName)[i] = metricValsInRow.get(colName); + } + + final Map aggregateValsInRow = result + .getAggregatesAsMap(aggregateNames); + for (String colName : doubleResultsByCol.keySet()) { + doubleResultsByCol.get(colName)[i] = aggregateValsInRow.get(colName); + } + ++i; + } + + // Generate DataFrame with results + final DataFrame df = new DataFrame(); + for (String attr : stringResultsByCol.keySet()) { + df.addColumn(attr, stringResultsByCol.get(attr)); + } + for (String attr : doubleResultsByCol.keySet()) { + df.addColumn(attr, doubleResultsByCol.get(attr)); + } + return df; + } } diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanationResult.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanationResult.java index 35c828a07..e3eb2811e 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanationResult.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLExplanationResult.java @@ -1,26 +1,29 @@ package edu.stanford.futuredata.macrobase.analysis.summary.aplinear; +import edu.stanford.futuredata.macrobase.analysis.summary.util.AttributeEncoder; import edu.stanford.futuredata.macrobase.analysis.summary.util.IntSet; -import edu.stanford.futuredata.macrobase.analysis.summary.util.IntSetAsLong; import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.QualityMetric; -import edu.stanford.futuredata.macrobase.analysis.summary.util.AttributeEncoder; - -import java.util.*; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; /** * Subgroup which meets the quality threshold metrics. */ public class APLExplanationResult { + private QualityMetric[] metricTypes; private IntSet matcher; private double[] aggregates; private double[] metrics; public APLExplanationResult( - QualityMetric[] metricTypes, - IntSet matcher, - double[] aggregates, - double[] metrics + QualityMetric[] metricTypes, + IntSet matcher, + double[] aggregates, + double[] metrics ) { this.metricTypes = metricTypes; this.matcher = matcher; @@ -28,7 +31,33 @@ public APLExplanationResult( this.metrics = metrics; } - private Map prettyPrintMatch(AttributeEncoder encoder) { + /** + * @return A Map with each metric value associated with the corresponding name of the metric + */ + Map getMetricsAsMap() { + final Map map = new HashMap<>(); + + for (int i = 0; i < metricTypes.length; i++) { + map.put(metricTypes[i].name(), metrics[i]); + } + return map; + } + + /** + * @param aggregateNames which aggregates to include in the Map + * @return A Map with each aggregate value associated with the corresponding name of the + * aggregate. + */ + Map getAggregatesAsMap(final List aggregateNames) { + final Map map = new HashMap<>(); + + for (int i = 0; i < aggregates.length; i++) { + map.put(aggregateNames.get(i), aggregates[i]); + } + return map; + } + + Map prettyPrintMatch(AttributeEncoder encoder) { Set values = matcher.getSet(); Map match = new HashMap<>(); @@ -57,7 +86,7 @@ private Map prettyPrintAggregate(List aggregateNames) { } public Map> jsonPrint(AttributeEncoder encoder, - List aggregateNames) { + List aggregateNames) { return new HashMap>() {{ put("matcher", prettyPrintMatch(encoder)); put("metric", prettyPrintMetric()); @@ -72,22 +101,23 @@ private String removeBrackets(String str) { } public String toString() { - return "a="+matcher.toString()+":ag="+Arrays.toString(aggregates)+":mt="+Arrays.toString(metrics); + return "a=" + matcher.toString() + ":ag=" + Arrays.toString(aggregates) + ":mt=" + Arrays + .toString(metrics); } public String prettyPrint( - AttributeEncoder encoder, - List aggregateNames + AttributeEncoder encoder, + List aggregateNames ) { String metricString = removeBrackets(prettyPrintMetric().toString()); String matchString = removeBrackets(prettyPrintMatch(encoder).toString()); String aggregateString = removeBrackets(prettyPrintAggregate(aggregateNames).toString()); return String.format( - "%s: %s\n%s: %s\n%s: %s\n", - "metrics", metricString, - "matches", matchString, - "aggregates", aggregateString + "%s: %s\n%s: %s\n%s: %s\n", + "metrics", metricString, + "matches", matchString, + "aggregates", aggregateString ); } } diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLOutlierSummarizer.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLOutlierSummarizer.java index 47ea19500..d5df13737 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLOutlierSummarizer.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLOutlierSummarizer.java @@ -2,20 +2,20 @@ import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.GlobalRatioQualityMetric; import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.QualityMetric; +import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.RiskRatioQualityMetric; import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.SupportQualityMetric; import edu.stanford.futuredata.macrobase.datamodel.DataFrame; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * Summarizer that works over both cube and row-based labeled ratio-based - * outlier summarization. + * Summarizer that works over both cube and row-based labeled ratio-based outlier summarization. */ public class APLOutlierSummarizer extends APLSummarizer { + private Logger log = LoggerFactory.getLogger("APLOutlierSummarizer"); private String countColumn = null; @@ -26,13 +26,14 @@ public List getAggregateNames() { @Override public int[][] getEncoded(List columns, DataFrame input) { - return encoder.encodeAttributesWithSupport(columns, minOutlierSupport, input.getDoubleColumnByName(outlierColumn)); + return encoder.encodeAttributesWithSupport(columns, minOutlierSupport, + input.getDoubleColumnByName(outlierColumn)); } @Override public double[][] getAggregateColumns(DataFrame input) { double[] outlierCol = input.getDoubleColumnByName(outlierColumn); - double[] countCol = processCountCol(input, countColumn, outlierCol.length); + double[] countCol = processCountCol(input, countColumn, outlierCol.length); double[][] aggregateColumns = new double[2][]; aggregateColumns[0] = outlierCol; @@ -45,11 +46,20 @@ public double[][] getAggregateColumns(DataFrame input) { public List getQualityMetricList() { List qualityMetricList = new ArrayList<>(); qualityMetricList.add( - new SupportQualityMetric(0) - ); - qualityMetricList.add( - new GlobalRatioQualityMetric(0, 1) + new SupportQualityMetric(0) ); + switch (ratioMetric) { + case "risk_ratio": + case "riskratio": + qualityMetricList.add( + new RiskRatioQualityMetric(0, 1)); + break; + case "global_ratio": + case "globalratio": + default: + qualityMetricList.add( + new GlobalRatioQualityMetric(0, 1)); + } return qualityMetricList; } @@ -71,9 +81,11 @@ public double getNumberOutliers(double[][] aggregates) { public String getCountColumn() { return countColumn; } + public void setCountColumn(String countColumn) { this.countColumn = countColumn; } + public double getMinRatioMetric() { return minRatioMetric; } diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLSummarizer.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLSummarizer.java index 324ab6604..24e1ad79b 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLSummarizer.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APLSummarizer.java @@ -48,6 +48,8 @@ protected double[] processCountCol(DataFrame input, String countColumn, int numR } return countCol; } + + public void process(DataFrame input) throws Exception { encoder = new AttributeEncoder(); encoder.setColumnNames(attributes); @@ -69,6 +71,7 @@ public void process(DataFrame input) throws Exception { List aplResults = aplKernel.explain(encoded, aggregateColumns, encoder.getNextKey(), + maxOrder, numThreads ); log.info("Number of results: {}", aplResults.size()); @@ -80,7 +83,6 @@ public void process(DataFrame input) throws Exception { numOutliers, aggregateNames, qualityMetricList, - thresholds, aplResults ); } diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APrioriLinear.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APrioriLinear.java index ae39aa892..f045ce682 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APrioriLinear.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/aplinear/APrioriLinear.java @@ -47,6 +47,7 @@ public List explain( final int[][] attributes, double[][] aggregateColumns, int cardinality, + final int maxOrder, int numThreads ) { final int numAggregates = aggregateColumns.length; @@ -97,7 +98,7 @@ public List explain( aRows[i][j] = aggregateColumns[j][i]; } } - for (int curOrder = 1; curOrder <= 3; curOrder++) { + for (int curOrder = 1; curOrder <= maxOrder; curOrder++) { long startTime = System.currentTimeMillis(); final int curOrderFinal = curOrder; // Initialize per-thread hashmaps. diff --git a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/ratios/ExplanationMetric.java b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/ratios/ExplanationMetric.java index 6991be6d5..f205b4b67 100644 --- a/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/ratios/ExplanationMetric.java +++ b/lib/src/main/java/edu/stanford/futuredata/macrobase/analysis/summary/ratios/ExplanationMetric.java @@ -1,7 +1,5 @@ package edu.stanford.futuredata.macrobase.analysis.summary.ratios; -import edu.stanford.futuredata.macrobase.util.MacrobaseException; - /** * Calculate generic metrics to quantify the severity of a classification result. * Can be extended in the future to also return confidence intervals. @@ -20,23 +18,4 @@ public abstract double calc( public String name() { return this.getClass().toString(); } - - /** - * @param metricName a String that maps to a particular subclass of ExplanationMetric - * @return an instantiation of the subclass. "risk_ratio/riskratio" -> {@link RiskRatioMetric}, - * "global_ratio/globalratio" -> {@link GlobalRatioMetric} - */ - public static ExplanationMetric getMetricFn(final String metricName) throws MacrobaseException { - switch (metricName.toLowerCase()) { - case "risk_ratio": - case "riskratio": - return new RiskRatioMetric(); - case "global_ratio": - case "globalratio": - return new GlobalRatioMetric(); - default: - throw new MacrobaseException(metricName + " is not a valid ExplanationMetric"); - - } - } } diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java index abe064b7b..acccaae91 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/QueryEngine.java @@ -7,8 +7,7 @@ import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import edu.stanford.futuredata.macrobase.analysis.MBFunction; -import edu.stanford.futuredata.macrobase.analysis.summary.apriori.APrioriSummarizer; -import edu.stanford.futuredata.macrobase.analysis.summary.ratios.ExplanationMetric; +import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.APLOutlierSummarizer; import edu.stanford.futuredata.macrobase.datamodel.DataFrame; import edu.stanford.futuredata.macrobase.datamodel.Schema.ColType; import edu.stanford.futuredata.macrobase.ingest.CSVDataFrameParser; @@ -30,6 +29,7 @@ import edu.stanford.futuredata.macrobase.sql.tree.QueryBody; import edu.stanford.futuredata.macrobase.sql.tree.QuerySpecification; import edu.stanford.futuredata.macrobase.sql.tree.Relation; +import edu.stanford.futuredata.macrobase.sql.tree.Select; import edu.stanford.futuredata.macrobase.sql.tree.SelectItem; import edu.stanford.futuredata.macrobase.sql.tree.SingleColumn; import edu.stanford.futuredata.macrobase.sql.tree.SortItem; @@ -45,7 +45,6 @@ import java.util.BitSet; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -62,9 +61,11 @@ class QueryEngine { private static final Logger log = LoggerFactory.getLogger(QueryEngine.class.getSimpleName()); private final Map tablesInMemory; + private final int numThreads; QueryEngine() { tablesInMemory = new HashMap<>(); + numThreads = 1; // TODO: add configuration parameter for numThreads } /** @@ -171,20 +172,25 @@ private DataFrame executeDiffQuerySpec(final DiffQuerySpecification diffQuery) // TODO: if an explainCol isn't in the SELECT clause, don't include it final double minRatioMetric = diffQuery.getMinRatioExpression().getMinRatio(); final double minSupport = diffQuery.getMinSupportExpression().getMinSupport(); - final ExplanationMetric ratioMetric = ExplanationMetric - .getMetricFn(diffQuery.getRatioMetricExpr().getFuncName().toString()); + final String ratioMetric = diffQuery.getRatioMetricExpr().getFuncName().toString(); final int order = diffQuery.getMaxCombo().getValue(); // execute diff - final APrioriSummarizer summarizer = new APrioriSummarizer(); + final APLOutlierSummarizer summarizer = new APLOutlierSummarizer(); summarizer.setRatioMetric(ratioMetric) .setMaxOrder(order) .setMinSupport(minSupport) .setMinRatioMetric(minRatioMetric) .setOutlierColumn(outlierColName) - .setAttributes(explainCols); + .setAttributes(explainCols) + .setNumThreads(numThreads); - summarizer.process(dfToExplain); + try { + summarizer.process(dfToExplain); + } catch (Exception e) { + // TODO: get rid of this Exception + e.printStackTrace(); + } final DataFrame resultDf = summarizer.getResults().toDataFrame(explainCols); return evaluateSQLClauses(diffQuery, resultDf); @@ -215,23 +221,20 @@ private List findExplanationColumns(DataFrame dfToExplain) { } /** - * Removes all values in the SELECT clause of a given query that are {@link FunctionCall} - * objects, which are UDFs such as "percentile(column_name)". + * Returns all values in the SELECT clause of a given query that are {@link FunctionCall} + * objects, which are UDFs (e.g., "percentile(column_name)"). * - * @param selectItems the values in the SELECT clause to be modified in place - * @return The values that were removed from `selectItems`, returned as a List of {@link + * @param select The Select clause + * @return The items in the Select clause that correspond to UDFs returned as a List of {@link * SingleColumn} */ - private List removeUDFsInSelect(List selectItems) { + private List getUDFsInSelect(final Select select) { final List functionCalls = new ArrayList<>(); - Iterator it = selectItems.iterator(); - while (it.hasNext()) { - final SelectItem item = it.next(); + for (SelectItem item : select.getSelectItems()) { if (item instanceof SingleColumn) { final SingleColumn col = (SingleColumn) item; if (col.getExpression() instanceof FunctionCall) { functionCalls.add(col); - it.remove(); } } } @@ -270,28 +273,10 @@ private DataFrame concatOutliersAndInliers(final String outlierColName, */ private DataFrame evaluateSQLClauses(final QueryBody query, final DataFrame df) throws MacrobaseException { - // TODO: we need to figure out a smarter ordering of these. For example, - // if we have an ORDER BY, we don't need to sort columns that are never going to be in the - // final output (i.e. the ones not in the SELECT). Basically, we need to do two passes of - // SELECTs: one with all original projections + the columns in the WHERE clauses and ORDER BY - // clauses, and then a second with just the original projections. That should be correct - // and give us better performance. - - final List selectWithoutUdfs = Lists - .newArrayList(query.getSelect().getSelectItems()); - final List udfCols = removeUDFsInSelect(selectWithoutUdfs); - // selectWithoutUdfs has now been modified so that it no longer has UDFs - - // create shallow copy, so modifications don't persist on the original DataFrame - DataFrame resultDf = df.copy(); - final Map newColumns = evaluateUDFs(resultDf, udfCols); - - resultDf = evaluateWhereClause(df, query.getWhere()); - resultDf = evaluateSelectClause(resultDf, selectWithoutUdfs); - for (Map.Entry newColumn : newColumns.entrySet()) { - // add UDF columns to result - resultDf.addColumn((String) newColumn.getKey(), (double[]) newColumn.getValue()); - } + DataFrame resultDf = evaluateUDFs(df, getUDFsInSelect(query.getSelect())); + resultDf = evaluateWhereClause(resultDf, query.getWhere()); + resultDf = evaluateSelectClause(resultDf, query.getSelect()); + // TODO: what if you order by something that's not in the SELECT clause? resultDf = evaluateOrderByClause(resultDf, query.getOrderBy()); return evaluateLimitClause(resultDf, query.getLimit()); } @@ -347,10 +332,11 @@ private DataFrame getTable(String tableName) throws MacrobaseSQLException { * @param udfCols The List of UDFs to evaluate * @return The Map of new columns to be added */ - private Map evaluateUDFs(final DataFrame inputDf, - final List udfCols) + private DataFrame evaluateUDFs(final DataFrame inputDf, final List udfCols) throws MacrobaseException { - final Map newColumns = new HashMap<>(); + + // create shallow copy, so modifications don't persist on the original DataFrame + final DataFrame resultDf = inputDf.copy(); for (SingleColumn udfCol : udfCols) { final FunctionCall func = (FunctionCall) udfCol.getExpression(); // for now, if UDF is a.b.c.d(), ignore "a.b.c." @@ -360,9 +346,9 @@ private Map evaluateUDFs(final DataFrame inputDf, func.getArguments().stream().map(Expression::toString).findFirst().get()); // modify resultDf in place, add column; mbFunction is evaluated on input DataFrame - newColumns.put(udfCol.toString(), mbFunction.apply(inputDf)); + resultDf.addColumn(udfCol.toString(), mbFunction.apply(inputDf)); } - return newColumns; + return resultDf; } /** @@ -371,13 +357,16 @@ private Map evaluateUDFs(final DataFrame inputDf, * support for DISTINCT queries * * @param df The DataFrame to apply the Select clause on - * @param items The list of individual columns included in the Select clause + * @param select The Select clause * @return A new DataFrame with the result of the Select clause applied */ - private DataFrame evaluateSelectClause(DataFrame df, List items) { - if (items.size() == 1 && items.get(0) instanceof AllColumns) { - // SELECT * -> relation is unchanged - return df; + private DataFrame evaluateSelectClause(DataFrame df, final Select select) { + final List items = select.getSelectItems(); + for (SelectItem item : items) { + // If we find '*' -> relation is unchanged + if (item instanceof AllColumns) { + return df; + } } final List projections = items.stream().map(SelectItem::toString) .collect(toImmutableList()); @@ -393,6 +382,7 @@ private DataFrame evaluateSelectClause(DataFrame df, List items) { * clause * @return A new DataFrame with the result of the LIMIT clause applied */ + private DataFrame evaluateLimitClause(final DataFrame df, final Optional limitStr) { if (limitStr.isPresent()) { try { diff --git a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java index 0367475b7..0036f93c8 100644 --- a/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java +++ b/sql/src/main/java/edu/stanford/futuredata/macrobase/sql/tree/SingleColumn.java @@ -15,7 +15,6 @@ import static java.util.Objects.requireNonNull; -import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; @@ -80,22 +79,19 @@ public int hashCode() { @Override public String toString() { - // column name for the UDF is either 1) the user-provided alias, or 2) the function name - // and arguments concatenated by "_" - return alias.map(Identifier::toString).orElseGet(() -> formatForCol(expression)); + // column name for the UDF is either 1) the user-provided alias, or + // 2) the function name and arguments + return alias.map(Identifier::toString).orElseGet(() -> formatForColName(expression)); } /** - * @return If the Expression is a Function Call (e.g., a UDF), concatenate the function name and - * the arguments with "_". Otherwise, return the output of toString() + * @return If the Expression is a Function Call (e.g., a UDF), rewrite as "fn_name(arg1, arg2,… + * argn)". (By default, {@link FunctionCall#toString()} will include quotes around the function + * name.) Otherwise, return the output of toString() */ - private String formatForCol(final Expression expr) { + private String formatForColName(final Expression expr) { if (expr instanceof FunctionCall) { - final FunctionCall func = (FunctionCall) expr; - // for now, if UDF is a.b.c.d(), ignore "a.b.c." - final String funcName = func.getName().getSuffix(); - return funcName + "_" + Joiner.on("_") - .join(func.getArguments().stream().map(Expression::toString).iterator()); + return expr.toString().replaceAll("\"", ""); } return expr.toString();