Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,754 changes: 905 additions & 849 deletions src/com/jgaap/backend/API.java

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/com/jgaap/classifiers/BurrowsDelta.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,4 @@ public List<Pair<String, Double>> analyze(Document unknown) {
Collections.sort(results);
return results;
}
}
}
9 changes: 6 additions & 3 deletions src/com/jgaap/classifiers/KNearestNeighborDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@ public class KNearestNeighborDriver extends NeighborAnalysisDriver {

private static final int DEFAULT_K = 5;
private static final String DEFAULT_TIE = "lastPicked";

public KNearestNeighborDriver() {
addParams("k", "K", "5", new String[] {"1","2","3","4","5","6","7","8","9","10"}, false);
addParams("k", "K: Number of Neighbors", "5", new String[] { "1", "2",
"3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13",
"14", "15", "16", "17", "18", "19", "20", "21", "22", "23",
"24", "25" }, false);
}

public String displayName() {
Expand Down Expand Up @@ -107,7 +110,7 @@ public List<Pair<String, Double>> analyze(Document unknown) throws AnalyzeExcept
}

List<Pair<String, Double>> results = ballot.getResults();
Comparator<Pair<String, Double>> compareByScore = (Pair<String, Double> r1, Pair<String, Double> r2) -> r2.getSecond().compareTo(r1.getSecond());
Comparator<Pair<String, Double>> compareByScore = (Pair<String, Double> r1, Pair<String, Double> r2) -> r1.getSecond().compareTo(r2.getSecond());
Collections.sort(results, compareByScore);

return results;
Expand Down
144 changes: 144 additions & 0 deletions src/com/jgaap/classifiers/LeaveOneOutKNearestNeighborDriver.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* JGAAP -- a graphical program for stylometric authorship attribution
* Copyright (C) 2009,2011 by Patrick Juola
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package com.jgaap.classifiers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;


import com.google.common.collect.ImmutableList;
import com.jgaap.generics.AnalyzeException;
import com.jgaap.generics.DistanceCalculationException;
import com.jgaap.generics.ValidationDriver;
import com.jgaap.util.Ballot;
import com.jgaap.util.Document;
import com.jgaap.util.EventMap;
import com.jgaap.util.Pair;

/**
 * K-nearest-neighbor classifier intended for leave-one-out cross-validation.
 *
 * <p>The document under test is compared against every trained document except
 * those equal to itself; the {@code k} closest documents then vote for their
 * author. The vote of the neighbor at rank {@code i} (1-based) is worth
 * {@code 1 + 2^-i}, so closer neighbors contribute slightly more.
 *
 * @author Alejandro Jorge Napolitano Jawerbaum
 */
public class LeaveOneOutKNearestNeighborDriver extends ValidationDriver {

    private static final java.util.logging.Logger logger =
            java.util.logging.Logger.getLogger(LeaveOneOutKNearestNeighborDriver.class.getName());

    /** Trained documents paired with their event maps; assigned by {@link #train(List)}. */
    private ImmutableList<Pair<Document, EventMap>> knowns;

    private static final int DEFAULT_K = 5;
    private static final String DEFAULT_TIE = "lastPicked";

    /** Registers the {@code k} parameter (choices 1..25, default 5). */
    public LeaveOneOutKNearestNeighborDriver() {
        addParams("k", "K: Number of Neighbors", "5", new String[] { "1", "2",
                "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13",
                "14", "15", "16", "17", "18", "19", "20", "21", "22", "23",
                "24", "25" }, false);
    }

    /** @return the driver's display name followed by {@code getDistanceName()} */
    @Override
    public String displayName() {
        return "Leave One Out K-Nearest Neighbor driver" + this.getDistanceName();
    }

    @Override
    public String tooltipText() {
        return " ";
    }

    @Override
    public boolean showInGUI() {
        return true;
    }

    /**
     * Builds and stores an {@link EventMap} for every known document.
     *
     * @param knowns the documents to train on
     */
    @Override
    public void train(List<Document> knowns) {
        ImmutableList.Builder<Pair<Document, EventMap>> builder = ImmutableList.builder();
        for (Document known : knowns) {
            builder.add(new Pair<Document, EventMap>(known, new EventMap(known)));
        }
        this.knowns = builder.build();
    }

    /**
     * Ranks authors for {@code unknown} by weighted votes of its nearest neighbors.
     *
     * @param unknown the document to attribute; trained documents equal to it are skipped
     * @return (author, vote total) pairs sorted by vote total, highest first
     * @throws AnalyzeException if the distance function fails; the original
     *         {@link DistanceCalculationException} is attached as the cause
     */
    @Override
    public List<Pair<String, Double>> analyze(Document unknown) throws AnalyzeException {
        int k = getParameter("k", DEFAULT_K);
        String tieBreaker = getParameter("tieBreaker", DEFAULT_TIE);

        List<Pair<String, Double>> rawResults = new ArrayList<Pair<String, Double>>();

        // The unknown's event map does not depend on the known document, so it is
        // built once (and only if at least one comparison actually happens).
        EventMap unknownEvents = null;
        for (Pair<Document, EventMap> known : knowns) {
            if (known.getFirst().equals(unknown)) {
                logger.info("Excluded document that's being tested.");
                continue;
            }
            if (unknownEvents == null) {
                unknownEvents = new EventMap(unknown);
            }
            double current;
            try {
                current = distance.distance(unknownEvents, known.getSecond());
            } catch (DistanceCalculationException e) {
                // initCause keeps the root failure without assuming which
                // constructors AnalyzeException offers.
                AnalyzeException failure =
                        new AnalyzeException("Distance " + distance.displayName() + " failed");
                failure.initCause(e);
                throw failure;
            }
            rawResults.add(new Pair<String, Double>(known.getFirst().getAuthor(), current, 2));
        }

        Collections.sort(rawResults);

        Ballot<String> ballot = new Ballot<String>();
        int voters = Math.min(k, rawResults.size());
        for (int i = 0; i < voters; i++) {
            Pair<String, Double> neighbor = rawResults.get(i);
            // Integer part counts votes; the 2^-(i+1) fraction records the rank of each vote.
            ballot.vote(neighbor.getFirst(), 1 + Math.pow(2, -1.0 * (i + 1)));
        }

        if (tieBreaker.equals("lastPicked")) {
            ballot.setComparator(new LastPickedComparator());
        }

        List<Pair<String, Double>> results = ballot.getResults();
        Comparator<Pair<String, Double>> compareByScore =
                (Pair<String, Double> r1, Pair<String, Double> r2) -> r2.getSecond().compareTo(r1.getSecond());
        Collections.sort(results, compareByScore);

        return results;
    }

    /**
     * Tie-breaking comparator handed to the {@link Ballot} when the
     * {@code tieBreaker} parameter is {@code "lastPicked"}.
     *
     * <p>Scores are first compared by their integer parts. When those match, the
     * fractional parts are meant to decide which author received the later vote.
     */
    private static class LastPickedComparator implements Comparator<Pair<String, Double>>, Serializable {

        private static final long serialVersionUID = 1L;

        @Override
        public int compare(Pair<String, Double> firstPair, Pair<String, Double> secondPair) {
            double first = firstPair.getSecond();
            double second = secondPair.getSecond();

            // If the overall rank was not the same, then return these according to rank.
            if ((int) first != (int) second) {
                return Integer.compare((int) first, (int) second);
            }

            // NOTE(review): for non-negative scores (int)x - x is never positive, so this
            // loop and the test below appear never to fire and equal-rank pairs always
            // yield -1 (including compare(a, a)). The operands look inverted, but the
            // intended ordering depends on how Ballot consumes this comparator, so the
            // logic is left as written. Confirm before changing.

            // Otherwise, we want to move the decimal point right until we have an integer.
            while (((int) first - first) > 0.0000001) {
                first *= 2;
                second *= 2;
            }
            // If first had fewer decimal places than second, this means the last first vote
            // came BEFORE the last second vote.
            if (((int) second - second) > 0.0000001) {
                return 1;
            }
            // Otherwise, the last second vote came before the last first vote.
            return -1;
        }
    }
}
8 changes: 5 additions & 3 deletions src/com/jgaap/classifiers/LeaveOneOutNoDistanceDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ public List<Pair<String, Double>> analyze(Document fakeUnknown) throws AnalyzeEx
// document. We call this document a fake unknown because it is actually known,
// but we want to pretend that it isn't.
List<Document> knownsTemp = new ArrayList<>();
for(Document known : knownDocuments)
if(known != fakeUnknown)
for(Document known : knownDocuments) {
if(!known.equals(fakeUnknown))
knownsTemp.add(known);

}
// Set the analysisDriver's parameters.
// Pass the temporary known list and the fake unknown to the analysis driver that this
// driver depends on, and return the result.
analysisDriver.setParamGUI(getParamGUI());
analysisDriver.train(knownsTemp);
return analysisDriver.analyze(fakeUnknown);
}
Expand Down
35 changes: 35 additions & 0 deletions src/com/jgaap/classifiers/WEKALogisticRegression.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.jgaap.classifiers;

import java.util.List;

import com.jgaap.generics.AnalyzeException;
import com.jgaap.generics.WEKAAnalysisDriver;
import com.jgaap.util.Document;

import weka.classifiers.Classifier;

/**
 * Analysis driver backed by WEKA's {@code weka.classifiers.functions.Logistic}
 * classifier.
 */
public class WEKALogisticRegression extends WEKAAnalysisDriver {

    /** @return the name under which this driver is listed */
    @Override
    public String displayName() {
        return "WEKA Logistic Regression";
    }

    /** @return a one-line description of the classifier */
    @Override
    public String tooltipText() {
        return "Multinomial logistic regression, Courtesy of WEKA";
    }

    /** @return {@code true}; this driver is offered in the GUI */
    @Override
    public boolean showInGUI() {
        return true;
    }

    /** @return a new WEKA {@code Logistic} instance on every call */
    public Classifier getClassifier() {
        Classifier logistic = new weka.classifiers.functions.Logistic();
        return logistic;
    }

    /**
     * Checks nothing: this driver places no requirements on the known documents.
     *
     * @param knownList the known documents (unused)
     * @throws AnalyzeException never thrown by this implementation
     */
    public void testRequirements(List<Document> knownList) throws AnalyzeException {
        // No requirements
    }

}
142 changes: 142 additions & 0 deletions src/com/jgaap/classifiers/weightedVoting.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package com.jgaap.classifiers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import com.jgaap.backend.AnalysisDrivers;
import com.jgaap.backend.DistanceFunctions;
import com.jgaap.generics.AnalysisDriver;
import com.jgaap.generics.AnalyzeException;
import com.jgaap.generics.DistanceFunction;
import com.jgaap.generics.NeighborAnalysisDriver;
import com.jgaap.generics.ValidationDriver;
import com.jgaap.util.Document;
import com.jgaap.util.Pair;
import com.jgaap.util.WeightingMethod;
/**
 * Combines several analysis drivers by weighted vote.
 *
 * <p>Each configured classifier is weighted by {@link WeightingMethod}; classifiers
 * whose weight reaches the configured cutoff then each cast their weight for their
 * top-ranked author on every unknown document. See {@link #tooltipText()} for the
 * user-facing description.
 *
 * <p>Sets are used instead of lists to guard against the same algorithm voting
 * multiple times.
 *
 * @author Alejandro J Napolitano Jawerbaum
 */
public class weightedVoting extends AnalysisDriver {
    /** Classifiers selected by the last call to {@link #train(List)}. */
    public Set<AnalysisDriver> classifiers = new HashSet<AnalysisDriver>();

    // Training state is per instance: as static fields it leaked between
    // instances, and authors from earlier training runs were never discarded.
    private Set<Pair<AnalysisDriver, Double>> weightedClassifiers = new HashSet<Pair<AnalysisDriver, Double>>();
    private Set<Pair<AnalysisDriver, Double>> weights = new HashSet<Pair<AnalysisDriver, Double>>();
    private Set<String> authors = new HashSet<String>();
    private List<Document> knowns = new ArrayList<Document>();

    private static final Logger logger = Logger.getLogger(weightedVoting.class);

    public weightedVoting() {
        addParams("Classifiers", "Classifiers to be put to a vote.","Comma-separated list. Add | before parameters.", new String[] {""}, true); //TODO: Get all classifiers and add them to the array, then call each of them
        addParams("Distances", "Distance metrics for distance dependent Analysis Drivers","Comma-separated list", new String[] {""}, true);
        addParams("WeightingMethod", "Way to weight the classifiers.", "cross-validation", new String[]{"cross-validation", "accuracyOverSum", "none"}, false);
        addParams("Cutoff", "Minimum cross-validation score to consider an algorithm's vote.", "75", new String[]{"0", "10", "20","30","40","45","50","55","60","65","70","75","80","85","90","95", "100"}, true);
        addParams("VotingMethod", "Voting Method.", "sum", new String[] {"sum", "sum/count"}, false);
        addParams("AuthorsForCrossval", "Comma separated list of Authors to cross-validate. Empty = All.", "", new String[] {}, true);
    }

    @Override
    public String displayName() {
        return "Weighted Voting";
    }

    @Override
    public String tooltipText() {
        return "Takes in a list of analysis drivers, and put them to a vote on each unknown document. Warning: We recommend including independent classifiers only.";
    }

    @Override
    public boolean showInGUI() {
        return true;
    }

    /**
     * Selects, weights, filters and trains the voting classifiers.
     *
     * <p>Classifier names come from the {@code Classifiers} parameter. Validation
     * drivers and nested weighted-voting drivers are skipped. A
     * {@link NeighborAnalysisDriver} is instantiated once per entry of the
     * {@code Distances} parameter. A classifier that cannot be created is logged
     * and skipped. Classifiers passing the {@code Cutoff} (percent; {@code "0"}
     * accepts all) are finally trained on the full known set so that
     * {@link #analyze(Document)} does not have to retrain them per document.
     *
     * @param knownDocuments the documents of known authorship
     * @throws AnalyzeException if {@code Cutoff} is not numeric or a selected
     *         classifier fails to train
     */
    @Override
    public void train(List<Document> knownDocuments) throws AnalyzeException {
        Set<String> trainedAuthors = new HashSet<String>();
        for (Document doc : knownDocuments) {
            trainedAuthors.add(doc.getAuthor());
        }
        authors = trainedAuthors;
        knowns = knownDocuments;

        classifiers = buildClassifiers();
        weights = WeightingMethod.weight(classifiers, knownDocuments, getParameter("WeightingMethod"), getParameter("AuthorsForCrossval"));
        weightedClassifiers = applyCutoff(weights, getParameter("Cutoff"));

        for (Pair<AnalysisDriver, Double> weightedClassifier : weightedClassifiers) {
            logger.info("Training " + weightedClassifier.getFirst().displayName() + " for analysis");
            weightedClassifier.getFirst().train(knowns);
            logger.info("Finished training " + weightedClassifier.getFirst().displayName() + " for analysis");
        }
    }

    /** Instantiates the drivers named by the {@code Classifiers} parameter. */
    private Set<AnalysisDriver> buildClassifiers() {
        Set<AnalysisDriver> selected = new HashSet<AnalysisDriver>();
        String[] distanceNames = getParameter("Distances").split(",");
        for (String rawName : getParameter("Classifiers").split(",")) {
            String name = rawName.trim();
            try {
                AnalysisDriver classifier = AnalysisDrivers.getAnalysisDriver(name);
                // Exclusion is tested first so it also applies to excluded drivers
                // that happen to be NeighborAnalysisDriver subclasses.
                if (classifier instanceof LeaveOneOutNoDistanceDriver
                        || classifier instanceof ValidationDriver
                        || classifier instanceof weightedVoting) {
                    logger.info("Excluded cross-validation driver. Or worse, a weighted voting inception.");
                } else if (classifier instanceof NeighborAnalysisDriver) {
                    for (String distanceName : distanceNames) {
                        DistanceFunction dist = DistanceFunctions.getDistanceFunction(distanceName.trim());
                        // A separate driver per distance: reusing one instance would
                        // leave only the last distance configured.
                        NeighborAnalysisDriver perDistance =
                                (NeighborAnalysisDriver) AnalysisDrivers.getAnalysisDriver(name);
                        perDistance.setDistance(dist);
                        selected.add(perDistance);
                    }
                } else {
                    selected.add(classifier);
                }
            } catch (Exception e) {
                // Best effort: a misconfigured entry is reported and skipped.
                logger.error("Could not set up classifier '" + name + "' for weighted voting", e);
            }
        }
        return selected;
    }

    /** Keeps the weighted classifiers whose weight is at least {@code cutoffParam} percent. */
    private static Set<Pair<AnalysisDriver, Double>> applyCutoff(
            Set<Pair<AnalysisDriver, Double>> candidates, String cutoffParam) throws AnalyzeException {
        Set<Pair<AnalysisDriver, Double>> accepted = new HashSet<Pair<AnalysisDriver, Double>>();
        if (cutoffParam.equals("0")) {
            // No cutoff: every classifier votes.
            accepted.addAll(candidates);
            return accepted;
        }
        double cutoff;
        try {
            cutoff = Double.parseDouble(cutoffParam) / 100;
        } catch (NumberFormatException e) {
            AnalyzeException failure = new AnalyzeException("Cutoff must be a number, got: " + cutoffParam);
            failure.initCause(e);
            throw failure;
        }
        for (Pair<AnalysisDriver, Double> candidate : candidates) {
            if (candidate.getSecond() >= cutoff) {
                accepted.add(candidate);
            }
        }
        return accepted;
    }

    /**
     * Analyzes the unknown document and tallies the weighted votes.
     *
     * @param unknownDocument the document to be analyzed
     * @return for every known author, the sum of the weights of the classifiers
     *         whose top result matched that author (0.0 when none did)
     * @throws AnalyzeException if a classifier fails to analyze the document
     */
    public Map<String, Double> vote(Document unknownDocument) throws AnalyzeException {
        List<Pair<String, Double>> authorVote = new ArrayList<Pair<String, Double>>();

        for (Pair<AnalysisDriver, Double> weightedClassifier : weightedClassifiers) {
            AnalysisDriver driver = weightedClassifier.getFirst();
            List<Pair<String, Double>> results = driver.analyze(unknownDocument);
            if (results == null || results.isEmpty()) {
                logger.warn(driver.displayName() + " returned no results for document " + unknownDocument.getTitle() + "; its vote is skipped.");
                continue;
            }
            String winner = results.get(0).getFirst();
            logger.info(driver.displayName() + ". weight = " + weightedClassifier.getSecond() + ". Voted for " + winner + " for document " + unknownDocument.getTitle());
            authorVote.add(new Pair<String, Double>(winner, weightedClassifier.getSecond()));
        }

        //We should check the results for ties, and let the score be 0 for all authors if that is the case.
        Map<String, Double> authorVoteSumMap = new HashMap<String, Double>();
        for (String author : authors) {
            double totalVote = 0.0;
            for (Pair<String, Double> vote : authorVote) {
                // NOTE(review): substring match, so an author whose name is contained in
                // another author's name also collects that author's votes. Kept because
                // some drivers may decorate the author label; confirm before tightening
                // to equals().
                if (vote.getFirst().contains(author)) {
                    totalVote += vote.getSecond();
                }
            }
            authorVoteSumMap.put(author, totalVote);
        }
        logger.info(authorVoteSumMap);
        return authorVoteSumMap;
    }

    /**
     * Ranks the known authors for {@code unknownDocument} by weighted vote.
     *
     * @param unknownDocument the document to attribute
     * @return (author, vote total) pairs sorted by vote total, highest first
     * @throws AnalyzeException if a classifier fails to analyze the document
     */
    @Override
    public List<Pair<String, Double>> analyze(Document unknownDocument) throws AnalyzeException {
        Map<String, Double> authorVoteSumMap = vote(unknownDocument);

        List<Pair<String, Double>> authorVoteSum = new ArrayList<Pair<String, Double>>(authors.size());
        for (String author : authors) {
            authorVoteSum.add(new Pair<String, Double>(author, authorVoteSumMap.get(author)));
        }

        Comparator<Pair<String, Double>> compareByScore =
                (Pair<String, Double> r1, Pair<String, Double> r2) -> r2.getSecond().compareTo(r1.getSecond());
        Collections.sort(authorVoteSum, compareByScore);
        return authorVoteSum;
    }
}
Loading