Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainable relation extractor #364

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## Unreleased

### Added

- New trainable `eds.relation_detector_ffn` component to detect relations between entities. These relations are stored in each entity: `head._.rel[relation_label] = [tail1, tail2, ...]`.

# v0.15.0 (2024-12-13)

### Added
Expand Down
6 changes: 3 additions & 3 deletions edsnlp/core/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,6 @@ def shuffle(
# same program twice, the shuffling should be the same in both cases.
# This is not garanteed by just creating random.Random() which does not
# account
seed = seed if seed is not None else random.getrandbits(32)
if shuffle_reader:
if shuffle_reader not in self.reader.emitted_sentinels:
raise ValueError(f"Cannot shuffle by {shuffle_reader}")
Expand All @@ -807,13 +806,14 @@ def shuffle(
config=stream.config,
)
stream.reader.shuffle = shuffle_reader
stream.reader.rng = random.Random(seed)
if seed is not None:
stream.reader.rng = random.Random(seed)
if any(not op.elementwise for op in self.ops) or not shuffle_reader:
stream = stream.map_batches(
pipe=shuffle,
batch_size=batch_size,
batch_by=batch_by,
kwargs={"rng": random.Random(seed)},
kwargs={"rng": random.Random(seed)} if seed is not None else {},
)
stream.validate_ops(ops=stream.ops, update=False)
return stream
Expand Down
11 changes: 11 additions & 0 deletions edsnlp/data/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def __call__(self, obj, tokenizer=None):
doc = tok(obj["text"] or "")
doc._.note_id = obj.get("doc_id", obj.get(FILENAME))

entities = {}
spans = []

for dst in (
Expand Down Expand Up @@ -303,6 +304,7 @@ def __call__(self, obj, tokenizer=None):
)
span._.set(new_name, value)

entities.setdefault(ent["entity_id"], []).append(span)
spans.append(span)

set_spans(doc, spans, span_setter=self.span_setter)
Expand All @@ -311,6 +313,15 @@ def __call__(self, obj, tokenizer=None):
if span._.get(attr) is None:
span._.set(attr, value)

for relation in obj.get("relations", []):
relation_label = relation["relation_label"]
from_entity_id = relation["from_entity_id"]
to_entity_id = relation["to_entity_id"]

for head in entities[from_entity_id]:
for tail in entities[to_entity_id]:
head._.rel.setdefault(relation_label, set()).add(tail)

return doc


Expand Down
22 changes: 11 additions & 11 deletions edsnlp/data/standoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,17 @@ def dump_standoff_file(
attribute_idx += 1

# fmt: off
# if "relations" in doc:
# for i, relation in enumerate(doc["relations"]):
# entity_from = entities_ids[relation["from_entity_id"]]
# entity_to = entities_ids[relation["to_entity_id"]]
# print(
# "R{}\t{} Arg1:{} Arg2:{}\t".format(
# i + 1, str(relation["label"]), entity_from,
# entity_to
# ),
# file=f,
# )
if "relations" in doc:
for i, relation in enumerate(doc["relations"]):
entity_from = entities_ids[relation["from_entity_id"]]
entity_to = entities_ids[relation["to_entity_id"]]
print(
"R{}\t{} Arg1:{} Arg2:{}\t".format(
i + 1, str(relation["label"]), entity_from,
entity_to
),
file=f,
)
# fmt: on


Expand Down
5 changes: 4 additions & 1 deletion edsnlp/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import date, datetime

from dateutil.parser import parse as parse_date
from spacy.tokens import Doc
from spacy.tokens import Doc, Span

if not Doc.has_extension("note_id"):
Doc.set_extension("note_id", default=None)
Expand Down Expand Up @@ -43,3 +43,6 @@ def get_note_datetime(doc):

# Patient birth date/time — presumably set by data converters; defaults to None.
if not Doc.has_extension("birth_datetime"):
    Doc.set_extension("birth_datetime", default=None)

# Relations between entities: `head._.rel` maps a relation label to the set of
# tail spans for that head.
# NOTE(review): spaCy stores this mutable `{}` default as a single shared
# object — any span that mutates `span._.rel` in place (e.g. via
# `setdefault(label, set()).add(tail)`) without first assigning a fresh dict
# would mutate the dict seen by every other span. Verify that all writers
# assign a new dict per span, or switch to a per-span lazy getter.
if not Span.has_extension("rel"):
    Span.set_extension("rel", default={})
122 changes: 122 additions & 0 deletions edsnlp/metrics/relations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from collections import defaultdict
from itertools import product
from typing import Any, Optional

from edsnlp import registry
from edsnlp.metrics import Examples, make_examples, prf
from edsnlp.utils.span_getters import RelationCandidateGetter, get_spans
from edsnlp.utils.typing import AsList


def relations_scorer(
    examples: Examples,
    candidate_getter: AsList[RelationCandidateGetter],
    micro_key: str = "micro",
    filter_expr: Optional[str] = None,
):
    """
    Scores the relation predictions between a list of gold and predicted spans.

    Parameters
    ----------
    examples : Examples
        The examples to score, either a tuple of (golds, preds) or a list of
        spacy.training.Example objects
    candidate_getter : AsList[RelationCandidateGetter]
        The candidate getters to use to extract the possible relations from the
        documents. Each candidate getter should be a dictionary with the keys
        "head", "tail", and "labels". The "head" and "tail" keys should be
        SpanGetterArg objects, and the "labels" key should be a list of strings
        for these head-tail pairs.
    micro_key : str
        The key to use to store the micro-averaged results for spans of all types
    filter_expr : Optional[str]
        The filter expression to use to filter the documents

    Returns
    -------
    Dict[str, float]
    """
    examples = make_examples(examples)
    if filter_expr is not None:
        # NOTE: `filter_expr` is evaluated as Python — only pass trusted
        # expressions (this mirrors the other edsnlp metrics).
        filter_fn = eval(f"lambda doc: {filter_expr}")
        examples = [eg for eg in examples if filter_fn(eg.reference)]

    # annotations: {label -> (pred pairs, gold pairs, pred pair probabilities)}
    # The third slot is reserved for probability-aware metrics (e.g. AP) and is
    # currently unused.
    annotations = defaultdict(lambda: (set(), set(), dict()))
    annotations[micro_key] = (set(), set(), dict())

    def _collect(doc, eg_idx, head_getter, tail_getter, labels, slot):
        # Enumerate every head x tail candidate pair of `doc`, record the
        # (doc, head, tail, label) tuples that actually hold a relation into
        # `annotations[...][slot]` (0 = predictions, 1 = gold), and return the
        # number of candidate pairs seen.
        heads = [
            ((h.start, h.end, h.label_), h) for h in get_spans(doc, head_getter)
        ]
        tails = [
            ((t.start, t.end, t.label_), t) for t in get_spans(doc, tail_getter)
        ]
        n_pairs = 0
        for (h_key, head), (t_key, tail) in product(heads, tails):
            n_pairs += 1
            for label in labels:
                if tail in head._.rel.get(label, ()):
                    annotations[label][slot].add((eg_idx, h_key, t_key, label))
                    annotations[micro_key][slot].add((eg_idx, h_key, t_key, label))
        return n_pairs

    total_pred_count = 0
    total_gold_count = 0
    for candidate in candidate_getter:
        head_getter = candidate["head"]
        tail_getter = candidate["tail"]
        labels = candidate["labels"]
        for eg_idx, eg in enumerate(examples):
            total_pred_count += _collect(
                eg.predicted, eg_idx, head_getter, tail_getter, labels, 0
            )
            total_gold_count += _collect(
                eg.reference, eg_idx, head_getter, tail_getter, labels, 1
            )

    # A mismatch means the predicted docs do not carry the same candidate
    # entities as the gold docs, so precision/recall would be meaningless.
    if total_pred_count != total_gold_count:
        raise ValueError(
            f"Number of predicted and gold candidate pairs differ: "
            f"{total_pred_count} != {total_gold_count}. Make sure that you are "
            "running your relation detection pipe on the gold entities, and "
            "not on spans predicted by another NER pipe in your model."
        )

    return {
        name: {
            **prf(pred, gold),
            # "ap": average_precision(pred_with_prob, gold),
        }
        for name, (pred, gold, pred_with_prob) in annotations.items()
    }


@registry.metrics.register("eds.relations")
class RelationsMetric:
    """Callable metric object wrapping `relations_scorer` with fixed settings."""

    def __init__(
        self,
        candidate_getter: AsList[RelationCandidateGetter],
        micro_key: str = "micro",
        filter_expr: Optional[str] = None,
    ):
        # Remember the scorer arguments; they are forwarded on every call.
        self.candidate_getter = candidate_getter
        self.micro_key = micro_key
        self.filter_expr = filter_expr

    # Reuse the functional scorer's docstring so both entry points stay in sync.
    __init__.__doc__ = relations_scorer.__doc__

    def __call__(self, *examples: Any):
        # Delegate to the functional scorer with the stored configuration.
        scorer_kwargs = {
            "candidate_getter": self.candidate_getter,
            "micro_key": self.micro_key,
            "filter_expr": self.filter_expr,
        }
        return relations_scorer(examples, **scorer_kwargs)
1 change: 1 addition & 0 deletions edsnlp/pipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .qualifiers.reported_speech.factory import create_component as reported_speech
from .qualifiers.reported_speech.factory import create_component as rspeech
from .trainable.ner_crf.factory import create_component as ner_crf
from .trainable.relation_detector_ffn.factory import create_component as relation_detector_ffn
from .trainable.biaffine_dep_parser.factory import create_component as biaffine_dep_parser
from .trainable.extractive_qa.factory import create_component as extractive_qa
from .trainable.span_classifier.factory import create_component as span_classifier
Expand Down
33 changes: 33 additions & 0 deletions edsnlp/pipes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from edsnlp.core import PipelineProtocol
from edsnlp.core.registries import CurriedFactory
from edsnlp.utils.span_getters import (
RelationCandidateGetter,
SpanGetter, # noqa: F401
SpanGetterArg, # noqa: F401
SpanSetter,
Expand All @@ -23,6 +24,7 @@
validate_span_getter, # noqa: F401
validate_span_setter,
)
from edsnlp.utils.typing import AsList


def value_getter(span: Span):
Expand Down Expand Up @@ -203,3 +205,34 @@ def qualifiers(self): # pragma: no cover
@qualifiers.setter
def qualifiers(self, value): # pragma: no cover
self.attributes = value


class BaseRelationDetectorComponent(BaseComponent, abc.ABC):
    """
    Abstract base class for trainable relation-detection components.

    Subclasses receive a list of "candidate getters", each describing which
    head spans and tail spans to pair up and which relation labels apply to
    those pairs.
    """

    # NOTE(review): these two attributes are declared but never assigned in
    # __init__ (the getters live inside `self.candidate_getter` entries) —
    # confirm whether subclasses set them or whether they can be removed.
    head_getter: SpanGetter
    tail_getter: SpanGetter
    # Sorted union of all relation labels across candidate getters.
    labels: List[str]

    def __init__(
        self,
        nlp: PipelineProtocol = None,
        name: str = None,
        *args,
        candidate_getter: AsList[RelationCandidateGetter],
        **kwargs,
    ):
        super().__init__(nlp, name, *args, **kwargs)
        # Normalize each candidate entry: validate the head/tail span getters
        # and keep the labels untouched.
        self.candidate_getter = [
            {
                "head": validate_span_getter(candidate["head"]),
                "tail": validate_span_getter(candidate["tail"]),
                "labels": candidate["labels"],
            }
            for candidate in candidate_getter
        ]
        # Deduplicate labels across all candidate getters and sort them so the
        # label ordering is deterministic (e.g. for output layer indexing).
        self.labels = sorted(
            {
                label
                for candidate in self.candidate_getter
                for label in candidate["labels"]
            }
        )
17 changes: 12 additions & 5 deletions edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,18 @@ def preprocess(
begins = []
ends = []

contexts_to_idx = {span: i for i, span in enumerate(contexts)}
contexts_to_idx = {}
for ctx in contexts:
contexts_to_idx[ctx] = len(contexts_to_idx)
dedup_contexts = sorted(contexts_to_idx, key=contexts_to_idx.get)
assert not pre_aligned or len(spans) == len(contexts), (
"When `pre_aligned` is True, the number of spans and contexts must be the "
"same."
)
aligned_contexts = (
[[c] for c in contexts]
[[c] for c in dedup_contexts]
if pre_aligned
else align_spans(contexts, spans, sort_by_overlap=True)
else align_spans(dedup_contexts, spans, sort_by_overlap=True)
)
for i, (span, ctx) in enumerate(zip(spans, aligned_contexts)):
if len(ctx) == 0 or ctx[0].start > span.start or ctx[0].end < span.end:
Expand All @@ -143,12 +146,16 @@ def preprocess(
sequence_idx.append(contexts_to_idx[ctx[0]])
begins.append(span.start - start)
ends.append(span.end - start)
assert begins[-1] >= 0, f"Begin offset is negative: {span.text}"
assert ends[-1] <= len(ctx[0]), f"End offset is out of bounds: {span.text}"
return {
"begins": begins,
"ends": ends,
"sequence_idx": sequence_idx,
"num_sequences": len(contexts),
"embedding": self.embedding.preprocess(doc, contexts=contexts, **kwargs),
"num_sequences": len(dedup_contexts),
"embedding": self.embedding.preprocess(
doc, contexts=dedup_contexts, **kwargs
),
"stats": {"spans": len(begins)},
}

Expand Down
1 change: 1 addition & 0 deletions edsnlp/pipes/trainable/relation_detector_ffn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .factory import create_component
9 changes: 9 additions & 0 deletions edsnlp/pipes/trainable/relation_detector_ffn/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from edsnlp import registry

from .relation_detector_ffn import RelationDetectorFFN

# Register the trainable relation detector under the "eds.relation_detector_ffn"
# factory name so it can be added via `nlp.add_pipe(...)` or loaded from a
# config file.
# NOTE(review): `assigns` is empty although the component presumably writes to
# `span._.rel` — confirm whether that should be declared here.
create_component = registry.factory.register(
    "eds.relation_detector_ffn",
    assigns=[],
    deprecated=[],
)(RelationDetectorFFN)
Loading