Skip to content

Commit

Permalink
OPENNLP-1623: Add more test cases for Coref component (#201)
Browse files Browse the repository at this point in the history
* OPENNLP-1623 Add more test cases for Coref component
- adds a ton of new test classes
- adjusts TrainSimilarityModel interface to return the trained model
- fixes several undetected class cast bugs due to '@SuppressWarnings("unchecked")' cover-up
- makes CorefTrainerTest (somehow) pass
- makes classes named *Enum* be actual enums
- uses arbitrary values in example sgml, thx @rzo1
  • Loading branch information
mawiesne authored Jan 5, 2025
1 parent edc23ba commit a380b05
Show file tree
Hide file tree
Showing 58 changed files with 15,900 additions and 601 deletions.
1 change: 1 addition & 0 deletions opennlp-coref/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
<groupId>net.sf.extjwnl</groupId>
<artifactId>extjwnl-data-wn31</artifactId>
<version>1.2</version>
<scope>runtime</scope>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@
import opennlp.tools.cmdline.AbstractConverterTool;
import opennlp.tools.coref.CorefSample;
import opennlp.tools.coref.CorefSampleStreamFactory;
import opennlp.tools.postag.POSSample;

/**
* Tool to convert multiple data formats into the native OpenNLP Coref training format.
*
* @see AbstractConverterTool
* @see CorefSample
* @see CorefSampleStreamFactory
*/
public class CoreferenceConverterTool extends AbstractConverterTool<CorefSample, CorefSampleStreamFactory.Parameters> {

public CoreferenceConverterTool() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ private void show(Parse p) {
}
}
}


@Override
public String getShortDescription() {
return "learnable noun phrase coreferencer";
return "Learnable Noun Phrase Coreferencer";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,44 @@
import java.io.IOException;

import opennlp.tools.cmdline.AbstractTrainerTool;
import opennlp.tools.cmdline.CmdLineUtil;
import opennlp.tools.cmdline.TerminateToolException;
import opennlp.tools.cmdline.coref.CoreferencerTrainerTool.TrainerToolParams;
import opennlp.tools.cmdline.params.TrainingToolParams;
import opennlp.tools.coref.CorefSample;
import opennlp.tools.coref.CorefTrainer;
import opennlp.tools.util.model.ModelUtil;

public class CoreferencerTrainerTool extends AbstractTrainerTool<CorefSample, TrainerToolParams> {

public interface TrainerToolParams extends TrainingParams, TrainingToolParams {
interface TrainerToolParams extends TrainingParams, TrainingToolParams {
}

/**
 * Creates the trainer tool by handing the {@link CorefSample} sample type and the
 * {@link TrainerToolParams} parameter type to the {@link AbstractTrainerTool} constructor.
 */
public CoreferencerTrainerTool() {
super(CorefSample.class, TrainerToolParams.class);
}

/**
 * @return The fixed, human-readable one-line description of this tool.
 */
@Override
public String getShortDescription() {
return "Trainer for a Learnable Noun Phrase Coreferencer";
}

/**
 * Runs the coreference training.
 * <p>
 * After delegating to {@code super.run(format, args)}, the training parameters are
 * loaded from {@code params.getParams()}; if none could be loaded, the defaults from
 * {@link ModelUtil#createDefaultTrainingParameters()} are used instead. Training is
 * then started via {@link CorefTrainer} with both boolean switches
 * ({@code useTreebank}, {@code useDiscourseModel}) set to {@code true}.
 *
 * @param format The format name, passed on to the super implementation.
 * @param args   The command line arguments, passed on to the super implementation.
 * @throws TerminateToolException Thrown with code {@code -1} if an {@link IOException}
 *                                occurs during training.
 */
@Override
public void run(String format, String[] args) {
  super.run(format, args);

  mlParams = CmdLineUtil.loadTrainingParameters(params.getParams(), false);
  if (mlParams == null) {
    // Nothing could be loaded, fall back to the library defaults.
    mlParams = ModelUtil.createDefaultTrainingParameters();
  }
  // NOTE(review): mlParams is not handed to CorefTrainer.train(..) below - confirm
  // whether the trainer is expected to pick it up some other way.

  final String modelDirectory = params.getModel().toString();
  try {
    CorefTrainer.train(modelDirectory, sampleStream, true, true);
  } catch (IOException e) {
    final String reason = "IO error while reading training data or indexing data: " + e.getMessage();
    throw new TerminateToolException(-1, reason, e);
  }
}

/**
 * Command line entry point: constructs a {@link CoreferencerTrainerTool} instance.
 * <p>
 * NOTE(review): the created instance is not retained, {@code args} is not read and no
 * {@code run(..)} call is made here - confirm whether invoking the tool was intended.
 *
 * @param args The command line arguments (currently unused by this method).
 */
public static void main(String[] args) {
new CoreferencerTrainerTool();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import opennlp.tools.cmdline.params.BasicTrainingParams;

/**
* TrainingParameters for Name Finder.
* TrainingParameters for the Co-Referencer.
* <p>
* Note: Do not use this class, internal use only!
*/
Expand Down
48 changes: 30 additions & 18 deletions opennlp-coref/src/main/java/opennlp/tools/coref/CorefSample.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,26 @@
import java.util.ArrayList;
import java.util.List;

import opennlp.tools.commons.Sample;
import opennlp.tools.parser.Parse;

public class CorefSample {
/**
* Encapsulates {@link Parse parses} that originate from a parsing operation on text.
*/
public class CorefSample implements Sample {

private final List<Parse> parses;

/**
 * Creates a sample backed by the given parses.
 * <p>
 * The list reference is stored as-is: no copy is made and no validation is performed
 * here.
 *
 * @param parses The {@link Parse parses} to encapsulate.
 */
public CorefSample(List<Parse> parses) {
this.parses = parses;
}


/**
* Converts the encapsulated {@link Parse parses} into
* {@link opennlp.tools.coref.mention.Parse Coref-related parse} instances.
*
* @return A list of converted Coref-related parses.
*/
public List<opennlp.tools.coref.mention.Parse> getParses() {

List<opennlp.tools.coref.mention.Parse> corefParses = new ArrayList<>();
Expand All @@ -41,22 +51,13 @@ public List<opennlp.tools.coref.mention.Parse> getParses() {

return corefParses;
}

@Override
public String toString() {

StringBuffer sb = new StringBuffer();

for (Parse parse : parses) {
parse.show(sb);
sb.append('\n');
}

sb.append('\n');

return sb.toString();
}


/**
* Parses a given text sample into a {@link CorefSample}.
*
* @param corefSampleString A non-empty text fragment which can have multiple lines.
* @return A valid {@link CorefSample} instance.
*/
public static CorefSample parse(String corefSampleString) {

List<Parse> parses = new ArrayList<>();
Expand All @@ -67,4 +68,15 @@ public static CorefSample parse(String corefSampleString) {

return new CorefSample(parses);
}

/**
 * Renders all encapsulated parses into a single string.
 * <p>
 * Each parse is written via {@code Parse#show(StringBuffer)} and followed by a line
 * break; one additional line break terminates the output.
 *
 * @return The textual form of this sample.
 */
@Override
public String toString() {
  // NOTE(review): a StringBuffer is handed to parse.show(..); presumably that method
  // requires this type - confirm before switching to StringBuilder.
  final StringBuffer text = new StringBuffer();
  parses.forEach(p -> {
    p.show(text);
    text.append('\n');
  });
  return text.append('\n').toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ public CorefSample read() throws IOException {
String document = samples.read();
if (document != null) {
return CorefSample.parse(document);
}
else {
} else {
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.io.IOException;

import opennlp.tools.cmdline.ArgumentParser;
import opennlp.tools.cmdline.CmdLineUtil;
import opennlp.tools.cmdline.ObjectStreamFactory;
import opennlp.tools.cmdline.StreamFactoryRegistry;
import opennlp.tools.cmdline.params.BasicFormatParams;
Expand Down Expand Up @@ -55,16 +54,11 @@ public static void registerFactory() {
public ObjectStream<CorefSample> create(String[] args) {
Parameters params = ArgumentParser.parse(args, Parameters.class);

CmdLineUtil.checkInputFile("Data", params.getData());
final MarkableFileInputStreamFactory factory;
try {
factory = new MarkableFileInputStreamFactory(params.getData());
final MarkableFileInputStreamFactory factory = new MarkableFileInputStreamFactory(params.getData());
return new CorefSampleDataStream(new ParagraphStream(new PlainTextByLineStream(factory, params.getEncoding())));
} catch (FileNotFoundException e) {
throw new RuntimeException("Error finding specified input file!", e);
}
try (ObjectStream<String> lineStream = new ParagraphStream(new PlainTextByLineStream(
factory, params.getEncoding()))) {
return new CorefSampleDataStream(lineStream);
} catch (IOException e) {
throw new RuntimeException("Error loading Coref samples from input data!", e);
}
Expand Down
67 changes: 30 additions & 37 deletions opennlp-coref/src/main/java/opennlp/tools/coref/CorefTrainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,30 @@
import java.util.List;
import java.util.Stack;

import opennlp.tools.commons.Trainer;
import opennlp.tools.coref.linker.DefaultLinker;
import opennlp.tools.coref.linker.Linker;
import opennlp.tools.coref.linker.LinkerMode;
import opennlp.tools.coref.linker.TreebankLinker;
import opennlp.tools.coref.mention.Mention;
import opennlp.tools.coref.mention.MentionContext;
import opennlp.tools.coref.mention.MentionFinder;
import opennlp.tools.coref.resolver.MaxentResolver;
import opennlp.tools.coref.sim.GenderModel;
import opennlp.tools.coref.sim.NumberModel;
import opennlp.tools.coref.sim.SimilarityModel;
import opennlp.tools.coref.sim.TrainSimilarityModel;
import opennlp.tools.coref.sim.TrainModel;
import opennlp.tools.ml.AbstractTrainer;
import opennlp.tools.parser.Parse;
import opennlp.tools.util.ObjectStream;

public class CorefTrainer {
/**
* A {@link Trainer} implementation for co-reference resolution models.
*
* @see Trainer
* @see CorefModel
* @see CorefSample
*/
public class CorefTrainer extends AbstractTrainer implements Trainer {

private static boolean containsToken(String token, Parse p) {
for (Parse node : p.getTagNodes()) {
Expand All @@ -49,52 +57,41 @@ private static boolean containsToken(String token, Parse p) {
}

/**
 * Collects the mentions of all parses contained in the given sample.
 * <p>
 * For every parse, the mentions are obtained from the {@code mentionFinder}. A mention
 * that has no parse attached yet is matched against the nodes of the underlying parse
 * tree: the first visited node whose span equals the mention's span and whose type
 * starts with {@code "NML"} is wrapped in a {@link DefaultParse}, which is then set as
 * the mention's parse and supplies the mention's id.
 *
 * @param sample        The sample providing the coref parses to inspect.
 * @param mentionFinder The finder used to detect mentions in each parse.
 * @return All mentions of all parses, in the order they were found.
 */
private static Mention[] getMentions(CorefSample sample, MentionFinder mentionFinder) {
  final List<Mention> collected = new ArrayList<>();

  for (opennlp.tools.coref.mention.Parse sentenceParse : sample.getParses()) {
    final Parse root = ((DefaultParse) sentenceParse).getParse();
    final Mention[] found = mentionFinder.getMentions(sentenceParse);

    for (Mention mention : found) {
      if (mention.getParse() != null) {
        // Already linked to a parse node, nothing to look up.
        continue;
      }

      // Walk the tree with an explicit stack; children are pushed in array order,
      // so the last child of a node is visited first.
      final Stack<Parse> pending = new Stack<>();
      pending.push(root);
      while (!pending.isEmpty()) {
        final Parse candidate = pending.pop();
        final boolean sameSpan = candidate.getSpan().equals(mention.getSpan());
        if (sameSpan && candidate.getType().startsWith("NML")) {
          final DefaultParse wrapped = new DefaultParse(candidate, sentenceParse.getSentenceNumber());
          mention.setParse(wrapped);
          mention.setId(wrapped.getEntityId());
          break;
        }
        for (Parse child : candidate.getChildren()) {
          pending.push(child);
        }
      }
    }

    collected.addAll(Arrays.asList(found));
  }

  return collected.toArray(new Mention[0]);
}

public static void train(String modelDirectory, ObjectStream<CorefSample> samples,
boolean useTreebank, boolean useDiscourseModel) throws IOException {

TrainSimilarityModel simTrain = SimilarityModel.trainModel(modelDirectory + "/coref/sim");
TrainSimilarityModel genTrain = GenderModel.trainModel(modelDirectory + "/coref/gen");
TrainSimilarityModel numTrain = NumberModel.trainModel(modelDirectory + "/coref/num");

useTreebank = true;
TrainModel<SimilarityModel> simTrain =
SimilarityModel.trainModel(modelDirectory + "/coref/sim");
TrainModel<GenderModel> genTrain =
GenderModel.trainModel(modelDirectory + "/coref/gen");
TrainModel<NumberModel> numTrain =
NumberModel.trainModel(modelDirectory + "/coref/num");

Linker simLinker;

Expand All @@ -115,35 +112,31 @@ public static void train(String modelDirectory, ObjectStream<CorefSample> sample
genTrain.setExtents(extentContexts);
numTrain.setExtents(extentContexts);
}

simTrain.trainModel();
genTrain.trainModel();
numTrain.trainModel();

MaxentResolver.setSimilarityModel(SimilarityModel.testModel(modelDirectory + "/coref" + "/sim"));

// Done with similarity training

// Now train the linkers

final SimilarityModel simModel = simTrain.trainModel();
final GenderModel genderModel = genTrain.trainModel();
final NumberModel numberModel = numTrain.trainModel();

// Done with similarity training, now train the linkers

// Training data needs to be read in again and the stream must be reset
samples.reset();

// Now train linkers
// Now create linkers
Linker trainLinker;
if (useTreebank) {
trainLinker = new TreebankLinker(modelDirectory + "/coref/", LinkerMode.TRAIN, useDiscourseModel);
}
else {
trainLinker = new DefaultLinker(modelDirectory + "/coref/", LinkerMode.TRAIN, useDiscourseModel);
trainLinker = new TreebankLinker(modelDirectory, LinkerMode.TRAIN,
simModel, genderModel, numberModel, useDiscourseModel, -1);
} else {
trainLinker = new DefaultLinker(modelDirectory, LinkerMode.TRAIN,
simModel, genderModel, numberModel, useDiscourseModel, -1);
}

for (CorefSample sample = samples.read(); sample != null; sample = samples.read()) {

Mention[] mentions = getMentions(sample, trainLinker.getMentionFinder());
trainLinker.setEntities(mentions);
}

trainLinker.train();
}

}
Loading

0 comments on commit a380b05

Please sign in to comment.