Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cascades to APLinear #245

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5,415 changes: 5,415 additions & 0 deletions Moments Cube Creation.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import edu.stanford.futuredata.macrobase.analysis.classify.PredicateCubeClassifier;
import edu.stanford.futuredata.macrobase.analysis.classify.QuantileClassifier;
import edu.stanford.futuredata.macrobase.analysis.classify.RawClassifier;
import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.APLExplanation;
import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.APLMeanSummarizer;
import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.APLOutlierSummarizer;
import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.APLSummarizer;
import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.*;
import edu.stanford.futuredata.macrobase.datamodel.DataFrame;
import edu.stanford.futuredata.macrobase.datamodel.Schema;
import edu.stanford.futuredata.macrobase.ingest.CSVDataFrameWriter;
Expand Down Expand Up @@ -50,7 +47,10 @@ public class CubePipeline implements Pipeline {
private boolean includeLo;
private Optional<String> meanColumn;
private Optional<String> stdColumn;
private Optional<String> minColumn;
private Optional<String> maxColumn;
private LinkedHashMap<String, Double> quantileColumns;
private List<String> momentColumns;

// Explanation
private List<String> attributes;
Expand Down Expand Up @@ -88,7 +88,10 @@ public CubePipeline(PipelineConfig conf) {
includeLo = conf.get("includeLo", true);
meanColumn = Optional.ofNullable(conf.get("meanColumn"));
stdColumn = Optional.ofNullable(conf.get("stdColumn"));
minColumn = Optional.ofNullable(conf.get("minColumn"));
maxColumn = Optional.ofNullable(conf.get("maxColumn"));
quantileColumns = conf.get("quantileColumns", new LinkedHashMap<String, Double>());
momentColumns = conf.get("momentColumns", new ArrayList<String>());

attributes = conf.get("attributes");
minSupport = conf.get("minSupport", 3.0);
Expand Down Expand Up @@ -173,6 +176,18 @@ private Map<String, Schema.ColType> getColTypes() throws MacrobaseException {
}
return colTypes;
}
case "moment": {
for (String col : momentColumns) {
colTypes.put(col, Schema.ColType.DOUBLE);
}
colTypes.put(minColumn
.orElseThrow(() -> new MacrobaseException("min column not present in config")),
Schema.ColType.DOUBLE);
colTypes.put(maxColumn
.orElseThrow(() -> new MacrobaseException("max column not present in config")),
Schema.ColType.DOUBLE);
return colTypes;
}
case "raw": {
colTypes.put(meanColumn.orElseThrow(
() -> new MacrobaseException("mean column not present in config")),
Expand Down Expand Up @@ -216,6 +231,12 @@ private CubeClassifier getClassifier() throws MacrobaseException {
() -> new MacrobaseException("metric column not present in config")),
predicateStr, cutoff);
}
case "moment": {
return new RawClassifier(
countColumn,
null
);
}

case "meanshift":
case "raw": {
Expand Down Expand Up @@ -244,6 +265,19 @@ private APLSummarizer getSummarizer(CubeClassifier classifier) throws Exception
summarizer.setMinStdDev(minRatioMetric);
return summarizer;
}
// case "moment": {
// APLMomentSummarizer summarizer = new APLMomentSummarizer();
// summarizer.setMinColumn(minColumn.orElseThrow(
// () -> new MacrobaseException("min column not present in config")));
// summarizer.setMaxColumn(maxColumn.orElseThrow(
// () -> new MacrobaseException("max column not present in config")));
// summarizer.setMomentColumns(momentColumns);
// summarizer.setAttributes(attributes);
// summarizer.setMinSupport(minSupport);
// summarizer.setMinRatioMetric(minRatioMetric);
// summarizer.setPercentile(cutoff);
// return summarizer;
// }
default: {
APLOutlierSummarizer summarizer = new APLOutlierSummarizer();
summarizer.setOutlierColumn(classifier.getOutputColumnName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
Expand All @@ -18,6 +19,7 @@
public class APLOutlierSummarizer extends APLSummarizer {
private Logger log = LoggerFactory.getLogger("APLOutlierSummarizer");
private String countColumn = null;
private boolean onlyUseSupport = false;

@Override
public List<String> getAggregateNames() {
Expand All @@ -42,15 +44,21 @@ public List<QualityMetric> getQualityMetricList() {
qualityMetricList.add(
new SupportMetric(0)
);
qualityMetricList.add(
new GlobalRatioMetric(0, 1)
);
if (!onlyUseSupport) {
qualityMetricList.add(
new GlobalRatioMetric(0, 1)
);
}
return qualityMetricList;
}

@Override
public List<Double> getThresholds() {
return Arrays.asList(minOutlierSupport, minRatioMetric);
if (onlyUseSupport) {
return Collections.singletonList(minOutlierSupport);
} else {
return Arrays.asList(minOutlierSupport, minRatioMetric);
}
}

@Override
Expand All @@ -72,4 +80,5 @@ public void setCountColumn(String countColumn) {
public double getMinRatioMetric() {
return minRatioMetric;
}
public void onlyUseSupport(boolean onlyUseSupport) { this.onlyUseSupport = onlyUseSupport; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
* Generic summarizer superclass that can be customized with
Expand All @@ -18,11 +20,12 @@
*/
public abstract class APLSummarizer extends BatchSummarizer {
Logger log = LoggerFactory.getLogger("APLSummarizer");
AttributeEncoder encoder;
APLExplanation explanation;
APrioriLinear aplKernel;
List<QualityMetric> qualityMetricList;
List<Double> thresholds;
protected AttributeEncoder encoder;
protected APLExplanation explanation;
protected APrioriLinear aplKernel;
protected boolean doContainment = true;
public List<QualityMetric> qualityMetricList;
protected List<Double> thresholds;

protected long numEvents = 0;
protected long numOutliers = 0;
Expand Down Expand Up @@ -66,10 +69,12 @@ public void process(DataFrame input) throws Exception {
qualityMetricList,
thresholds
);
aplKernel.setDoContainment(doContainment);

double[][] aggregateColumns = getAggregateColumns(input);
List<String> aggregateNames = getAggregateNames();
List<APLExplanationResult> aplResults = aplKernel.explain(encoded, aggregateColumns);
Map<String, int[]> aggregationOps = getAggregationOps();
List<APLExplanationResult> aplResults = aplKernel.explain(encoded, aggregateColumns, aggregationOps);
numOutliers = (long)getNumberOutliers(aggregateColumns);

explanation = new APLExplanation(
Expand All @@ -87,4 +92,9 @@ public APLExplanation getResults() {
return explanation;
}

public Map<String, int[]> getAggregationOps() {
return null;
}

public void setDoContainment(boolean doContainment) { this.doContainment = doContainment; }
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package edu.stanford.futuredata.macrobase.analysis.summary.aplinear;

import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.metrics.AggregationOp;
import edu.stanford.futuredata.macrobase.analysis.summary.aplinear.metrics.QualityMetric;
import edu.stanford.futuredata.macrobase.analysis.summary.apriori.APrioriSummarizer;
import edu.stanford.futuredata.macrobase.analysis.summary.apriori.IntSet;
Expand All @@ -8,8 +9,6 @@
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;

/**
Expand All @@ -24,6 +23,11 @@ public class APrioriLinear {
// **Parameters**
private QualityMetric[] qualityMetrics;
private double[] thresholds;
private boolean doContainment = true;

public long mergeTime = 0;
public long queryTime = 0;
private long start;

// **Cached values**

Expand All @@ -50,23 +54,36 @@ public APrioriLinear(
public List<APLExplanationResult> explain(
final List<int[]> attributes,
double[][] aggregateColumns
) {
return explain(attributes, aggregateColumns, null);
}

public List<APLExplanationResult> explain(
final List<int[]> attributes,
double[][] aggregateColumns,
AggregationOp[] aggregationOps
) {
final int numAggregates = aggregateColumns.length;
final int numRows = aggregateColumns[0].length;

// Quality metrics are initialized with global aggregates to
// allow them to determine the appropriate relative thresholds
double[] globalAggregates = new double[numAggregates];
start = System.nanoTime();
for (int j = 0; j < numAggregates; j++) {
globalAggregates[j] = 0;
AggregationOp curOp = aggregationOps[j];
globalAggregates[j] = curOp.initValue();
double[] curColumn = aggregateColumns[j];
for (int i = 0; i < numRows; i++) {
globalAggregates[j] += curColumn[i];
globalAggregates[j] = curOp.combine(globalAggregates[j], curColumn[i]);
}
}
mergeTime += System.nanoTime() - start;
start = System.nanoTime();
for (QualityMetric q : qualityMetrics) {
q.initialize(globalAggregates);
}
queryTime += System.nanoTime() - start;

// Row store for more convenient access
final double[][] aRows = new double[numRows][numAggregates];
Expand All @@ -89,6 +106,7 @@ public List<APLExplanationResult> explain(
threadSetAggregates.add(new HashMap<>());
}
final CountDownLatch doneSignal = new CountDownLatch(numThreads);
start = System.nanoTime();
for (int threadNum = 0; threadNum < numThreads; threadNum++) {
final int startIndex = (numRows * threadNum) / numThreads;
final int endIndex = (numRows * (threadNum + 1)) / numThreads;
Expand All @@ -108,10 +126,20 @@ public List<APLExplanationResult> explain(
double[] candidateVal = thisThreadSetAggregates.get(curCandidate);
if (candidateVal == null) {
thisThreadSetAggregates.put(curCandidate, Arrays.copyOf(aRows[i], numAggregates));
} else {
} else if (aggregationOps == null) {
for (int a = 0; a < numAggregates; a++) {
candidateVal[a] += aRows[i][a];
}
} else {
for (int a : aggregationOps.getOrDefault("add", new int[0])) {
candidateVal[a] += aRows[i][a];
}
for (int a : aggregationOps.getOrDefault("min", new int[0])) {
candidateVal[a] = Math.min(candidateVal[a], aRows[i][a]);
}
for (int a : aggregationOps.getOrDefault("max", new int[0])) {
candidateVal[a] = Math.max(candidateVal[a], aRows[i][a]);
}
}
}
}
Expand All @@ -134,38 +162,49 @@ public List<APLExplanationResult> explain(
double[] candidateVal = setAggregates.get(curCandidateKey);
if (candidateVal == null) {
setAggregates.put(curCandidateKey, Arrays.copyOf(curCandidateValue, numAggregates));
} else {
} else if (aggregationOps == null) {
for (int a = 0; a < numAggregates; a++) {
candidateVal[a] += curCandidateValue[a];
}
} else {
for (int a : aggregationOps.getOrDefault("add", new int[0])) {
candidateVal[a] += curCandidateValue[a];
}
for (int a : aggregationOps.getOrDefault("min", new int[0])) {
candidateVal[a] = Math.min(candidateVal[a], curCandidateValue[a]);
}
for (int a : aggregationOps.getOrDefault("max", new int[0])) {
candidateVal[a] = Math.max(candidateVal[a], curCandidateValue[a]);
}
}
}
}
mergeTime += System.nanoTime() - start;

HashSet<IntSet> curOrderNext = new HashSet<>();
HashSet<IntSet> curOrderSaved = new HashSet<>();
int pruned = 0;
for (IntSet curCandidate: setAggregates.keySet()) {
double[] curAggregates = setAggregates.get(curCandidate);
boolean canPassThreshold = true;
boolean isPastThreshold = true;
QualityMetric.Action action = QualityMetric.Action.KEEP;
start = System.nanoTime();
for (int i = 0; i < qualityMetrics.length; i++) {
QualityMetric q = qualityMetrics[i];
double t = thresholds[i];
canPassThreshold &= q.maxSubgroupValue(curAggregates) >= t;
isPastThreshold &= q.value(curAggregates) >= t;
action = QualityMetric.Action.combine(action, q.getAction(curAggregates, t));
}
if (canPassThreshold) {
queryTime += System.nanoTime() - start;
if (action == QualityMetric.Action.KEEP) {
// if a set is already past the threshold on all metrics,
// save it and no need for further exploration
if (isPastThreshold) {
curOrderSaved.add(curCandidate);
}
else {
// otherwise if a set still has potentially good subsets,
// save it for further examination
// save it and no need for further exploration if we do containment
curOrderSaved.add(curCandidate);
if (!doContainment) {
curOrderNext.add(curCandidate);
}
} else if (action == QualityMetric.Action.NEXT) {
// otherwise if a set still has potentially good subsets,
// save it for further examination
curOrderNext.add(curCandidate);
} else {
pruned++;
}
Expand Down Expand Up @@ -269,4 +308,6 @@ private ArrayList<IntSet> getCandidates(
}
return candidates;
}

public void setDoContainment(boolean doContainment) { this.doContainment = doContainment; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package edu.stanford.futuredata.macrobase.analysis.summary.aplinear.metrics;

import edu.stanford.futuredata.macrobase.util.MacrobaseInternalError;

/**
 * Aggregation operators used when merging cube rows in APrioriLinear.
 * Each op supplies its identity element ({@link #initValue()}) and a
 * binary combine step ({@link #combine(double, double)}), so accumulation
 * can be written generically as {@code acc = op.combine(acc, x)}.
 */
public enum AggregationOp {
    SUM, MIN, MAX;

    /**
     * Folds one value into a running aggregate.
     *
     * @param a the accumulated value so far
     * @param b the next value to fold in
     * @return the combined aggregate. Uses {@link Math#min}/{@link Math#max}
     *         so NaN inputs propagate instead of being silently dropped.
     */
    public double combine(double a, double b) {
        switch (this) {
            case SUM:
                return a + b;
            case MIN:
                return Math.min(a, b);
            default:
                // MAX — the switch is exhaustive over all enum constants,
                // so this branch can only be MAX.
                return Math.max(a, b);
        }
    }

    /**
     * Identity element for this operator: combining it with any value x
     * yields x. For MIN/MAX the true identities are the infinities;
     * {@code +/-Double.MAX_VALUE} would mishandle inputs outside that range.
     *
     * @return the identity value for this aggregation
     */
    public double initValue() {
        switch (this) {
            case SUM:
                return 0.0;
            case MIN:
                return Double.POSITIVE_INFINITY;
            default:
                // MAX — exhaustive switch, see combine().
                return Double.NEGATIVE_INFINITY;
        }
    }
}
Loading