Commit f36c5b7

feat: new eds.relation_detector_ffn trainable component
1 parent 5d9c033

6 files changed · +135 −103 lines changed

edsnlp/data/converters.py

Lines changed: 0 additions & 1 deletion
@@ -321,7 +321,6 @@ def __call__(self, obj, tokenizer=None):
             for head in entities[from_entity_id]:
                 for tail in entities[to_entity_id]:
                     head._.rel.setdefault(relation_label, set()).add(tail)
-                    print("NEW REL FROM", head, "TO", tail, "WITH", relation_label)
 
         return doc

edsnlp/metrics/relations.py

Lines changed: 44 additions & 48 deletions
@@ -4,15 +4,13 @@
 
 from edsnlp import registry
 from edsnlp.metrics import Examples, make_examples, prf
-from edsnlp.utils.span_getters import SpanGetterArg, get_spans
+from edsnlp.utils.span_getters import RelationCandidateGetter, get_spans
 from edsnlp.utils.typing import AsList
 
 
 def relations_scorer(
     examples: Examples,
-    head_getter: SpanGetterArg,
-    tail_getter: SpanGetterArg,
-    labels: AsList[str],
+    candidate_getter: AsList[RelationCandidateGetter],
     micro_key: str = "micro",
     filter_expr: Optional[str] = None,
 ):
@@ -24,12 +22,12 @@ def relations_scorer(
     examples : Examples
         The examples to score, either a tuple of (golds, preds) or a list of
         spacy.training.Example objects
-    head_getter : SpanGetterArg
-        The span getter to use to extract the relation heads from the document
-    tail_getter : SpanGetterArg
-        The span getter to use to extract the relation tails from the document
-    labels : Sequence[str]
-        The labels of the relations to evaluate
+    candidate_getter : AsList[RelationCandidateGetter]
+        The candidate getters to use to extract the possible relations from the
+        documents. Each candidate getter should be a dictionary with the keys
+        "head", "tail", and "labels". The "head" and "tail" keys should be
+        SpanGetterArg objects, and the "labels" key should be a list of strings
+        for these head-tail pairs.
     micro_key : str
         The key to use to store the micro-averaged results for spans of all types
     filter_expr : Optional[str]
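
To make the new parameter concrete, here is a minimal sketch of a single candidate-getter entry as described in the docstring above. The span-group names ("diseases", "drugs") and the label ("caused_by") are made-up placeholders, not names used in this repository:

candidate_getter = [
    {
        # "head" and "tail" accept any SpanGetterArg, e.g. a span-group mapping
        "head": {"diseases": True},
        "tail": {"drugs": True},
        # labels scored for every (head, tail) pair produced by this entry
        "labels": ["caused_by"],
    }
]
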
@@ -49,36 +47,40 @@ def relations_scorer(
     total_pred_count = 0
     total_gold_count = 0
 
-    for eg_idx, eg in enumerate(examples):
-        pred_heads = [
-            ((h.start, h.end, h.label_), h)
-            for h in get_spans(eg.predicted, head_getter)
-        ]
-        pred_tails = [
-            ((t.start, t.end, t.label_), t)
-            for t in get_spans(eg.predicted, tail_getter)
-        ]
-        for (h_key, head), (t_key, tail) in product(pred_heads, pred_tails):
-            total_pred_count += 1
-            for label in labels:
-                if tail in head._.rel.get(label, ()):
-                    annotations[label][0].add((eg_idx, h_key, t_key, label))
-                    annotations[micro_key][0].add((eg_idx, h_key, t_key, label))
+    for candidate in candidate_getter:
+        head_getter = candidate["head"]
+        tail_getter = candidate["tail"]
+        labels = candidate["labels"]
+        for eg_idx, eg in enumerate(examples):
+            pred_heads = [
+                ((h.start, h.end, h.label_), h)
+                for h in get_spans(eg.predicted, head_getter)
+            ]
+            pred_tails = [
+                ((t.start, t.end, t.label_), t)
+                for t in get_spans(eg.predicted, tail_getter)
+            ]
+            for (h_key, head), (t_key, tail) in product(pred_heads, pred_tails):
+                total_pred_count += 1
+                for label in labels:
+                    if tail in head._.rel.get(label, ()):
+                        annotations[label][0].add((eg_idx, h_key, t_key, label))
+                        annotations[micro_key][0].add((eg_idx, h_key, t_key, label))
 
-        gold_heads = [
-            ((h.start, h.end, h.label_), h)
-            for h in get_spans(eg.reference, head_getter)
-        ]
-        gold_tails = [
-            ((t.start, t.end, t.label_), t)
-            for t in get_spans(eg.reference, tail_getter)
-        ]
-        for (h_key, head), (t_key, tail) in product(gold_heads, gold_tails):
-            total_gold_count += 1
-            for label in labels:
-                if tail in head._.rel.get(label, ()):
-                    annotations[label][1].add((eg_idx, h_key, t_key, label))
-                    annotations[micro_key][1].add((eg_idx, h_key, t_key, label))
+            gold_heads = [
+                ((h.start, h.end, h.label_), h)
+                for h in get_spans(eg.reference, head_getter)
+            ]
+            gold_tails = [
+                ((t.start, t.end, t.label_), t)
+                for t in get_spans(eg.reference, tail_getter)
+            ]
+            for (h_key, head), (t_key, tail) in product(gold_heads, gold_tails):
+                total_gold_count += 1
+                for label in labels:
+                    if tail in head._.rel.get(label, ()):
+                        annotations[label][1].add((eg_idx, h_key, t_key, label))
+                        annotations[micro_key][1].add((eg_idx, h_key, t_key, label))
 
     if total_pred_count != total_gold_count:
         raise ValueError(
@@ -101,15 +103,11 @@ def relations_scorer(
 class RelationsMetric:
     def __init__(
         self,
-        head_getter: SpanGetterArg,
-        tail_getter: SpanGetterArg,
-        labels: AsList[str],
+        candidate_getter: AsList[RelationCandidateGetter],
         micro_key: str = "micro",
         filter_expr: Optional[str] = None,
     ):
-        self.head_getter = head_getter
-        self.tail_getter = tail_getter
-        self.labels = labels
+        self.candidate_getter = candidate_getter
         self.micro_key = micro_key
         self.filter_expr = filter_expr

@@ -118,9 +116,7 @@ def __init__(
     def __call__(self, *examples: Any):
         return relations_scorer(
            examples,
-            head_getter=self.head_getter,
-            tail_getter=self.tail_getter,
-            labels=self.labels,
+            candidate_getter=self.candidate_getter,
            micro_key=self.micro_key,
            filter_expr=self.filter_expr,
        )
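
For orientation on the refactored signature, a hedged usage sketch. The candidate-getter contents are placeholders, and `golds` / `preds` stand for aligned lists of documents passed as the (golds, preds) tuple accepted per the docstring:

from edsnlp.metrics.relations import RelationsMetric

metric = RelationsMetric(
    candidate_getter=[
        {"head": {"diseases": True}, "tail": {"drugs": True}, "labels": ["caused_by"]},
    ],
)
scores = metric(golds, preds)
# scores["micro"] holds the results aggregated over all labels (default micro_key)
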

edsnlp/pipes/base.py

Lines changed: 18 additions & 6 deletions
@@ -14,6 +14,7 @@
 from edsnlp.core import PipelineProtocol
 from edsnlp.core.registries import DraftPipe
 from edsnlp.utils.span_getters import (
+    RelationCandidateGetter,
     SpanGetter,  # noqa: F401
     SpanGetterArg,  # noqa: F401
     SpanSetter,
@@ -23,6 +24,7 @@
     validate_span_getter,  # noqa: F401
     validate_span_setter,
 )
+from edsnlp.utils.typing import AsList
 
 
 def value_getter(span: Span):
@@ -215,12 +217,22 @@ def __init__(
         nlp: PipelineProtocol = None,
         name: str = None,
         *args,
-        head_getter: SpanGetterArg,
-        tail_getter: SpanGetterArg,
-        labels: List[str],
+        candidate_getter: AsList[RelationCandidateGetter],
         **kwargs,
     ):
         super().__init__(nlp, name, *args, **kwargs)
-        self.head_getter: SpanGetter = validate_span_getter(head_getter)  # type: ignore
-        self.tail_getter: SpanGetter = validate_span_getter(tail_getter)  # type: ignore
-        self.labels = labels
+        self.candidate_getter = [
+            {
+                "head": validate_span_getter(candidate["head"]),
+                "tail": validate_span_getter(candidate["tail"]),
+                "labels": candidate["labels"],
+            }
+            for candidate in candidate_getter
+        ]
+        self.labels = sorted(
+            {
+                label
+                for candidate in self.candidate_getter
+                for label in candidate["labels"]
+            }
+        )
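
A small illustration of the normalization above (getter dicts are placeholders): with two candidate entries whose label lists overlap, the component-wide label space is the sorted union, so there is exactly one classifier output per distinct label:

candidate_getter = [
    {"head": {"diseases": True}, "tail": {"drugs": True}, "labels": ["treated_by", "caused_by"]},
    {"head": {"drugs": True}, "tail": {"dates": True}, "labels": ["started_on", "caused_by"]},
]
labels = sorted({lab for cand in candidate_getter for lab in cand["labels"]})
# labels == ["caused_by", "started_on", "treated_by"]
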

edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py

Lines changed: 59 additions & 45 deletions
@@ -25,7 +25,7 @@
     SpanEmbeddingComponent,
     WordEmbeddingComponent,
 )
-from edsnlp.utils.span_getters import SpanGetterArg, get_spans
+from edsnlp.utils.span_getters import RelationCandidateGetter, get_spans
 from edsnlp.utils.typing import AsList
 
 
@@ -57,13 +57,17 @@ def make_ranges(starts, ends):
     if 0 in ends.shape:
         return ends
     sizes = ends - starts
+    mask = sizes > 0
     offsets = sizes.cumsum(0)
     offsets = offsets.roll(1)
     res = torch.ones(offsets[0], dtype=torch.long)
     offsets[0] = 0
-    res[offsets] = starts
-    res[offsets[1:]] -= ends[:-1] - 1
-    return res.cumsum(0)
+    masked_offsets = offsets[mask]
+    starts = starts[mask]
+    ends = ends[mask]
+    res[masked_offsets] = starts
+    res[masked_offsets[1:]] -= ends[:-1] - 1
+    return res.cumsum(0), offsets
 
 
 logger = logging.getLogger(__name__)
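
A quick sanity trace of the updated helper (with the make_ranges from this file in scope and torch imported): the new mask skips empty ranges so they no longer corrupt the flattened index vector, and the returned offsets mark where each range, including empty ones, begins:

import torch

# token ranges [2, 4) and [9, 12), with an empty range [5, 5) in between
starts = torch.tensor([2, 5, 9])
ends = torch.tensor([4, 5, 12])
indices, offsets = make_ranges(starts, ends)
# indices -> tensor([ 2,  3,  9, 10, 11])
# offsets -> tensor([0, 2, 2])   # the empty middle range contributes no indices
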
@@ -85,41 +89,56 @@ def make_ranges(starts, ends):
 """
 
 
+class MLP(torch.nn.Module):
+    def __init__(
+        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_p: float = 0.0
+    ):
+        super().__init__()
+        self.hidden = torch.nn.Linear(input_dim, hidden_dim)
+        self.output = torch.nn.Linear(hidden_dim, output_dim)
+        self.dropout = torch.nn.Dropout(dropout_p)
+
+    def forward(self, x):
+        x = self.hidden(x)
+        x = F.relu(x)
+        x = self.dropout(x)
+        x = self.output(x)
+        return x
+
+
 class RelationDetectorFFN(
     TorchComponent[BatchOutput, FrameBatchInput],
     BaseRelationDetectorComponent,
 ):
     def __init__(
         self,
         nlp: Optional[PipelineProtocol] = None,
-        name: str = "rel_scope",
+        name: str = "relation_detector_ffn",
         *,
         span_embedding: SpanEmbeddingComponent,
         word_embedding: WordEmbeddingComponent,
-        head_getter: SpanGetterArg,
-        tail_getter: SpanGetterArg,
-        labels: AsList[str],
-        symmetric: bool = True,
+        candidate_getter: AsList[RelationCandidateGetter],
+        symmetric: bool = False,
+        hidden_size: int = 128,
+        dropout_p: float = 0.0,
     ):
         super().__init__(
             nlp=nlp,
             name=name,
-            head_getter=head_getter,
-            tail_getter=tail_getter,
-            labels=labels,
+            candidate_getter=candidate_getter,
         )
         self.span_embedding = span_embedding
         self.word_embedding = word_embedding
         self.symmetric = symmetric
-        # self.merge_mode = merge_mode
 
-        hidden_size = (
+        embed_size = (
             self.span_embedding.output_size * 2 + self.word_embedding.output_size
         )
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", UserWarning)
             # self.head_projection = torch.nn.Linear(hidden_size, hidden_size)
             # self.tail_projection = torch.nn.Linear(hidden_size, hidden_size)
+            self.mlp = MLP(embed_size, hidden_size, hidden_size, dropout_p)
             self.classifier = torch.nn.Linear(hidden_size, len(self.labels))
 
     @property
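
A standalone sketch of the dimensions wired up in this __init__, assuming the MLP class from this diff is in scope (the sizes below are arbitrary): each candidate pair is represented by the concatenation of its head span embedding, tail span embedding, and pooled inter-span word embedding, hence embed_size = 2 * span_size + word_size, projected to hidden_size by the MLP and scored with one logit per label:

import torch

span_size, word_size, hidden_size, n_labels = 96, 128, 128, 3
embed_size = span_size * 2 + word_size

mlp = MLP(embed_size, hidden_size, hidden_size, dropout_p=0.1)
classifier = torch.nn.Linear(hidden_size, n_labels)

rel_embeds = torch.randn(10, embed_size)    # 10 candidate (head, tail) pairs
logits = classifier(mlp(rel_embeds))        # shape: (10, n_labels)
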
@@ -155,22 +174,23 @@ def preprocess(self, doc: Doc, supervised: int = False) -> Dict[str, Any]:
         rel_labels = []
 
         all_spans = defaultdict(lambda: len(all_spans))
-        head_spans = list(get_spans(doc, self.head_getter))
-        tail_spans = list(get_spans(doc, self.tail_getter))
-
-        for head, tail in product(head_spans, tail_spans):
-            rel_head_idx.append(all_spans[head])
-            rel_tail_idx.append(all_spans[tail])
-            if supervised:
-                rel_labels.append(
-                    [
-                        (
-                            tail in head._.rel.get(lab, ())
-                            or (self.symmetric and head in tail._.rel.get(lab, ()))
-                        )
-                        for lab in self.labels
-                    ]
-                )
+
+        for candidate in self.candidate_getter:
+            head_spans = list(get_spans(doc, candidate["head"]))
+            tail_spans = list(get_spans(doc, candidate["tail"]))
+            for head, tail in product(head_spans, tail_spans):
+                rel_head_idx.append(all_spans[head])
+                rel_tail_idx.append(all_spans[tail])
+                if supervised:
+                    rel_labels.append(
+                        [
+                            (
+                                tail in head._.rel.get(lab, ())
+                                or (self.symmetric and head in tail._.rel.get(lab, ()))
+                            )
+                            for lab in self.labels
+                        ]
+                    )
 
         result = {
             "num_spans": len(all_spans),
@@ -231,17 +251,14 @@ def compute_inter_span_embeds(self, word_embeds, begins, ends, head_idx, tail_id
                 0, dim, dtype=word_embeds.dtype, device=word_embeds.device
             )
 
-        flat_begins = torch.minimum(
-            ends[head_idx],
-            ends[tail_idx],
-        )
-        flat_ends = torch.maximum(
-            begins[head_idx],
-            begins[tail_idx],
+        flat_begins = torch.minimum(ends[head_idx], ends[tail_idx])
+        flat_ends = torch.maximum(begins[head_idx], begins[tail_idx])
+        flat_begins, flat_ends = (
+            torch.minimum(flat_begins, flat_ends),
+            torch.maximum(flat_begins, flat_ends),
         )
         flat_embeds = word_embeds.view(-1, dim)
-        flat_indices = make_ranges(flat_begins, flat_ends)
-        flat_offsets = (flat_ends - flat_begins).cumsum(0).roll(1)
+        flat_indices, flat_offsets = make_ranges(flat_begins, flat_ends)
         flat_offsets[0] = 0
         inter_span_embeds = torch.nn.functional.embedding_bag(  # type: ignore
             input=flat_indices,
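
For intuition, a hedged sketch of how the (flat_indices, flat_offsets) pair returned by make_ranges feeds torch.nn.functional.embedding_bag; the pooling mode is not visible in this hunk, so "mean" is assumed purely for illustration:

import torch
import torch.nn.functional as F

word_embeds = torch.randn(20, 8)            # 20 word embeddings of size 8
flat_begins = torch.tensor([2, 5, 9])       # inter-span token ranges per pair
flat_ends = torch.tensor([4, 5, 12])        # the middle range is empty

flat_indices, flat_offsets = make_ranges(flat_begins, flat_ends)
flat_offsets[0] = 0                         # already zero; mirrors the calling code
inter_span_embeds = F.embedding_bag(
    input=flat_indices,
    weight=word_embeds,
    offsets=flat_offsets,
    mode="mean",                            # assumption, not shown in the diff
)
# inter_span_embeds.shape == (3, 8): one pooled vector per candidate pair,
# with zeros for the empty range
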
@@ -285,6 +302,7 @@ def forward(self, batch: FrameBatchInput) -> BatchOutput:
             ],
             dim=-1,
         )
+        rel_embeds = self.mlp(rel_embeds)
         logits = self.classifier(rel_embeds)
 
         losses = pred = None
@@ -321,12 +339,8 @@ def postprocess(
         Returns
         -------
         """
-        all_heads = [
-            prep["$spans"][idx] for prep in inputs for idx in prep["rel_heads"]
-        ]
-        all_tails = [
-            prep["$spans"][idx] for prep in inputs for idx in prep["rel_tails"]
-        ]
+        all_heads = [p["$spans"][idx] for p in inputs for idx in p["rel_heads"]]
+        all_tails = [p["$spans"][idx] for p in inputs for idx in p["rel_tails"]]
         for pair_idx, label_idx in results["pred"].nonzero(as_tuple=False).tolist():
             head = all_heads[pair_idx]
             tail = all_tails[pair_idx]
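
Finally, a hedged end-to-end sketch of adding the new component to a pipeline. The factory name eds.relation_detector_ffn follows the commit title and the parameter names come from the __init__ above, but the embedding sub-components (eds.transformer, eds.span_pooler), the model name, and the span-group / label names are illustrative assumptions, not part of this diff:

import edsnlp
import edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
embedding = eds.transformer(model="camembert-base", window=128, stride=96)
nlp.add_pipe(
    eds.relation_detector_ffn(
        word_embedding=embedding,
        span_embedding=eds.span_pooler(embedding=embedding),
        candidate_getter=[
            {"head": {"diseases": True}, "tail": {"drugs": True}, "labels": ["caused_by"]},
        ],
        hidden_size=128,
        dropout_p=0.1,
    ),
    name="relation_detector_ffn",
)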

edsnlp/utils/span_getters.py

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 from pydantic import NonNegativeInt
 from spacy.tokens import Doc, Span
+from typing_extensions import TypedDict
 
 from edsnlp import registry
 from edsnlp.utils.filter import filter_spans
@@ -321,3 +322,13 @@ def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]:
         end = max(end, max_end_sent)
 
         return span.doc[start:end]
+
+
+RelationCandidateGetter = TypedDict(
+    "CandidateGetter",
+    {
+        "head": SpanGetterArg,
+        "tail": SpanGetterArg,
+        "labels": AsList[str],
+    },
+)
