From 008bed07e53baf7ce87a5ce40daca2eb49398adf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Tue, 21 Jan 2025 22:27:17 +0100 Subject: [PATCH] fix: eds.split now keeps doc and span attributes in the sub-documents --- changelog.md | 1 + edsnlp/pipes/misc/split/split.py | 17 +++++++++++++---- tests/pipelines/misc/test_split.py | 20 +++++++++++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/changelog.md b/changelog.md index 0639d813e..910d14743 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,7 @@ ### Fixed - Support packaging with poetry 2.0 +- `eds.split` now keeps doc and span attributes in the sub-documents. # v0.15.0 (2024-12-13) diff --git a/edsnlp/pipes/misc/split/split.py b/edsnlp/pipes/misc/split/split.py index beab948f7..17db63384 100644 --- a/edsnlp/pipes/misc/split/split.py +++ b/edsnlp/pipes/misc/split/split.py @@ -64,14 +64,23 @@ def subset_doc(doc: Doc, start: int, end: int) -> Doc: ------- Doc """ - # TODO: review user_data copy strategy new_doc = doc[start:end].as_doc() - new_doc.user_data.update(doc.user_data) shifter = make_shifter(start, end, new_doc) - for key, val in list(new_doc.user_data.items()): - new_doc.user_data[key] = shifter(val) + char_beg = doc[start].idx if start < len(doc) else 0 + char_end = doc[end - 1].idx + len(doc[end - 1].text) + for k, val in list(doc.user_data.items()): + new_value = shifter(val) + if k[0] == "._." and new_value is not EMPTY: + new_doc.user_data[ + ( + k[0], + k[1], + None if k[2] is None else max(0, k[2] - char_beg), + None if k[3] is None else min(k[3] - char_beg, char_end - char_beg), + ) + ] = new_value for name, group in doc.spans.items(): new_doc.spans[name] = shifter(list(group)) diff --git a/tests/pipelines/misc/test_split.py b/tests/pipelines/misc/test_split.py index 34e3af460..d076bf358 100644 --- a/tests/pipelines/misc/test_split.py +++ b/tests/pipelines/misc/test_split.py @@ -1,3 +1,4 @@ +from spacy.tokens import Doc from spacy.tokens.span import Span import edsnlp @@ -12,16 +13,33 @@ def test_split_line_jump(): nlp.add_pipe("eds.matcher", config={"terms": {"test": "test"}}) doc = nlp(txt) Span.set_extension("test_dict", default={}, force=True) + Doc.set_extension("global_attr", default=None, force=True) + Span.set_extension("ent_attr", default=None, force=True) + doc._.global_attr = "global" doc.ents[0]._.test_dict = {"key": doc.ents[1]} + doc.ents[0]._.ent_attr = "ent-0" doc.ents[1]._.test_dict = {"key": doc.ents[0]} doc.ents[2]._.test_dict = {"key": doc.ents[0]} + doc.ents[2]._.ent_attr = "ent-2" + doc.spans["section"] = [doc[:]] + doc.spans["section"][0]._.ent_attr = "section" subdocs = list(eds.split(regex="\n\n")(doc)) + assert len(subdocs) == 2 assert subdocs[0].text == "This is a test. Another test.\n\n" - assert subdocs[1].text == "A third test!" + assert subdocs[0]._.global_attr == "global" assert subdocs[0].ents[0]._.test_dict["key"] == subdocs[0].ents[1] + assert subdocs[0].ents[0]._.ent_attr == "ent-0" assert subdocs[0].ents[1]._.test_dict["key"] == subdocs[0].ents[0] + assert subdocs[0].spans["section"][0].text == "This is a test. Another test.\n\n" + assert subdocs[0].spans["section"][0]._.ent_attr == "section" + + assert subdocs[1].text == "A third test!" + assert subdocs[1]._.global_attr == "global" assert subdocs[1].ents[0]._.test_dict == {} + assert subdocs[1].ents[0]._.ent_attr == "ent-2" + assert subdocs[1].spans["section"][0].text == "A third test!" + assert subdocs[1].spans["section"][0]._.ent_attr == "section" def test_filter():