Skip to content

Commit

Permalink
Incorporating Peter's changes to APriori
Browse files Browse the repository at this point in the history
  • Loading branch information
fabuzaid21 committed Feb 8, 2018
1 parent 8f5fdff commit 0bb2c45
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,31 @@

import edu.stanford.futuredata.macrobase.datamodel.DataFrame;
import edu.stanford.futuredata.macrobase.operator.Operator;

import edu.stanford.futuredata.macrobase.util.MacrobaseException;
import java.util.ArrayList;
import java.util.List;

/**
* Takes a dataframe with binary classification and searches for explanations
* (subgroup discovery / contrast set mining / feature selection)
* that capture differences between the two groups.
* Takes a dataframe with binary classification and searches for explanations (subgroup discovery /
* contrast set mining / feature selection) that capture differences between the two groups.
*
* outlierColumn should either be 0.0 or 1.0 to signify outlying points or
* a count of the number of outliers represented by a row
* outlierColumn should either be 0.0 or 1.0 to signify outlying points or a count of the number of
* outliers represented by a row
*/
public abstract class BatchSummarizer implements Operator<DataFrame, Explanation> {

// Parameters
protected String outlierColumn = "_OUTLIER";
protected double minOutlierSupport = 0.1;
protected double minRatioMetric = 3;
protected List<String> attributes = new ArrayList<>();
protected int numThreads = Runtime.getRuntime().availableProcessors();
protected String ratioMetric;
protected int maxOrder;

/**
* Adjust this to tune the significance (e.g. number of rows affected) of the results returned.
*
* @param minSupport lowest outlier support of the results returned.
*/
public BatchSummarizer setMinSupport(double minSupport) {
Expand All @@ -38,14 +41,17 @@ public BatchSummarizer setAttributes(List<String> attributes) {

/**
* Set the column which indicates outlier status. "_OUTLIER" by default.
*
* @param outlierColumn new outlier indicator column.
*/
public BatchSummarizer setOutlierColumn(String outlierColumn) {
this.outlierColumn = outlierColumn;
return this;
}

/**
* Adjust this to tune the severity (e.g. strength of correlation) of the results returned.
*
* @param minRatioMetric lowest risk ratio to consider for meaningful explanations.
*/

Expand All @@ -56,7 +62,25 @@ public BatchSummarizer setMinRatioMetric(double minRatioMetric) {

/**
* The number of threads used in parallel summarizers.
*
* @param numThreads Number of threads to use.
*/
public void setNumThreads(int numThreads) { this.numThreads = numThreads; }
public BatchSummarizer setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
}

public BatchSummarizer setRatioMetric(final String ratioMetric) {
this.ratioMetric = ratioMetric;
return this;
}

public BatchSummarizer setMaxOrder(final int maxOrder) throws MacrobaseException {
if (maxOrder < 1 || maxOrder > 3) {
throw new MacrobaseException("Max Order " + maxOrder +
" cannot be less than 1 or greater than 3");
}
this.maxOrder = maxOrder;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,37 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import edu.stanford.futuredata.macrobase.analysis.summary.Explanation;
import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.QualityMetric;
import edu.stanford.futuredata.macrobase.analysis.summary.util.AttributeEncoder;

import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.QualityMetric;
import edu.stanford.futuredata.macrobase.datamodel.DataFrame;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class APLExplanation implements Explanation {

private AttributeEncoder encoder;
private List<String> aggregateNames;
private long numTotal;
private long numOutliers;

private ArrayList<QualityMetric> metrics;
private List<Double> thresholds;
private ArrayList<APLExplanationResult> results;

public APLExplanation(
AttributeEncoder encoder,
long numTotal,
long numOutliers,
List<String> aggregateNames,
List<QualityMetric> metrics,
List<Double> thresholds,
List<APLExplanationResult> results
AttributeEncoder encoder,
long numTotal,
long numOutliers,
List<String> aggregateNames,
List<QualityMetric> metrics,
List<APLExplanationResult> results
) {
this.encoder = encoder;
this.numTotal = numTotal;
this.numOutliers = numOutliers;
this.aggregateNames = aggregateNames;
this.metrics = new ArrayList<>(metrics);
this.thresholds = new ArrayList<>(thresholds);
this.results = new ArrayList<>(results);
}

Expand Down Expand Up @@ -63,14 +62,77 @@ public double numOutliers() {
@Override
public String prettyPrint() {
StringBuilder header = new StringBuilder(String.format(
"Outlier Explanation:\n"
"Outlier Explanation:\n"
));
header.append("Outliers: "+numOutliers+", Total: "+numTotal+"\n");
header.append("Outliers: " + numOutliers + ", Total: " + numTotal + "\n");
for (APLExplanationResult is : results) {
header.append(
"---\n"+is.prettyPrint(encoder, aggregateNames)
"---\n" + is.prettyPrint(encoder, aggregateNames)
);
}
return header.toString();
}

/**
* Convert List of {@link APLExplanationResult} to normalized DataFrame that includes all
* metrics contained in each results.
*
* @param attrsToInclude the attributes (String columns) to be included in the DataFrame
* @return New DataFrame with <tt>attrsToInclude</tt> columns and ratio metric, support, and
* outlier count columns
*/
public DataFrame toDataFrame(final List<String> attrsToInclude) {
// String column values that will be added to DataFrame
final Map<String, String[]> stringResultsByCol = new HashMap<>();
for (String colName : attrsToInclude) {
stringResultsByCol.put(colName, new String[results.size()]);
}

// double column values that will be added to the DataFrame
final Map<String, double[]> doubleResultsByCol = new HashMap<>();
for (String colName : aggregateNames) {
doubleResultsByCol.put(colName, new double[results.size()]);
}
for (QualityMetric metric : metrics) {
// NOTE: we assume that the QualityMetrics here are the same ones
// that each APLExplanationResult has
doubleResultsByCol.put(metric.name(), new double[results.size()]);
}

// Add result rows to individual columns
int i = 0;
for (APLExplanationResult result : results) {
// attrValsInRow contains the values for the explanation attribute values in this
// given row
final Map<String, String> attrValsInRow = result.prettyPrintMatch(encoder);
for (String colName : stringResultsByCol.keySet()) {
// Iterate over all attributes that will be in the DataFrame.
// If attribute is present in attrValsInRow, add its corresponding value.
// Otherwise, add null
stringResultsByCol.get(colName)[i] = attrValsInRow.get(colName);
}

final Map<String, Double> metricValsInRow = result.getMetricsAsMap();
for (String colName : doubleResultsByCol.keySet()) {
doubleResultsByCol.get(colName)[i] = metricValsInRow.get(colName);
}

final Map<String, Double> aggregateValsInRow = result
.getAggregatesAsMap(aggregateNames);
for (String colName : doubleResultsByCol.keySet()) {
doubleResultsByCol.get(colName)[i] = aggregateValsInRow.get(colName);
}
++i;
}

// Generate DataFrame with results
final DataFrame df = new DataFrame();
for (String attr : stringResultsByCol.keySet()) {
df.addColumn(attr, stringResultsByCol.get(attr));
}
for (String attr : doubleResultsByCol.keySet()) {
df.addColumn(attr, doubleResultsByCol.get(attr));
}
return df;
}
}
Original file line number Diff line number Diff line change
@@ -1,34 +1,63 @@
package edu.stanford.futuredata.macrobase.analysis.summary.aplinear;

import edu.stanford.futuredata.macrobase.analysis.summary.util.AttributeEncoder;
import edu.stanford.futuredata.macrobase.analysis.summary.util.IntSet;
import edu.stanford.futuredata.macrobase.analysis.summary.util.IntSetAsLong;
import edu.stanford.futuredata.macrobase.analysis.summary.util.qualitymetrics.QualityMetric;
import edu.stanford.futuredata.macrobase.analysis.summary.util.AttributeEncoder;

import java.util.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Subgroup which meets the quality threshold metrics.
*/
public class APLExplanationResult {

private QualityMetric[] metricTypes;
private IntSet matcher;
private double[] aggregates;
private double[] metrics;

public APLExplanationResult(
QualityMetric[] metricTypes,
IntSet matcher,
double[] aggregates,
double[] metrics
QualityMetric[] metricTypes,
IntSet matcher,
double[] aggregates,
double[] metrics
) {
this.metricTypes = metricTypes;
this.matcher = matcher;
this.aggregates = aggregates;
this.metrics = metrics;
}

private Map<String, String> prettyPrintMatch(AttributeEncoder encoder) {
/**
* @return A Map with each metric value associated with the corresponding name of the metric
*/
Map<String, Double> getMetricsAsMap() {
final Map<String, Double> map = new HashMap<>();

for (int i = 0; i < metricTypes.length; i++) {
map.put(metricTypes[i].name(), metrics[i]);
}
return map;
}

/**
* @param aggregateNames which aggregates to include in the Map
* @return A Map with each aggregate value associated with the corresponding name of the
* aggregate.
*/
Map<String, Double> getAggregatesAsMap(final List<String> aggregateNames) {
final Map<String, Double> map = new HashMap<>();

for (int i = 0; i < aggregates.length; i++) {
map.put(aggregateNames.get(i), aggregates[i]);
}
return map;
}

Map<String, String> prettyPrintMatch(AttributeEncoder encoder) {
Set<Integer> values = matcher.getSet();
Map<String, String> match = new HashMap<>();

Expand Down Expand Up @@ -57,7 +86,7 @@ private Map<String, String> prettyPrintAggregate(List<String> aggregateNames) {
}

public Map<String, Map<String, String>> jsonPrint(AttributeEncoder encoder,
List<String> aggregateNames) {
List<String> aggregateNames) {
return new HashMap<String, Map<String, String>>() {{
put("matcher", prettyPrintMatch(encoder));
put("metric", prettyPrintMetric());
Expand All @@ -72,22 +101,23 @@ private String removeBrackets(String str) {
}

public String toString() {
return "a="+matcher.toString()+":ag="+Arrays.toString(aggregates)+":mt="+Arrays.toString(metrics);
return "a=" + matcher.toString() + ":ag=" + Arrays.toString(aggregates) + ":mt=" + Arrays
.toString(metrics);
}

public String prettyPrint(
AttributeEncoder encoder,
List<String> aggregateNames
AttributeEncoder encoder,
List<String> aggregateNames
) {
String metricString = removeBrackets(prettyPrintMetric().toString());
String matchString = removeBrackets(prettyPrintMatch(encoder).toString());
String aggregateString = removeBrackets(prettyPrintAggregate(aggregateNames).toString());

return String.format(
"%s: %s\n%s: %s\n%s: %s\n",
"metrics", metricString,
"matches", matchString,
"aggregates", aggregateString
"%s: %s\n%s: %s\n%s: %s\n",
"metrics", metricString,
"matches", matchString,
"aggregates", aggregateString
);
}
}
Loading

0 comments on commit 0bb2c45

Please sign in to comment.