Skip to content

Commit 39df77a

Browse files
committed
fix: eds.llm_markup_extractor context splitting now yields full docs and not parts of docs
1 parent 75dc01d commit 39df77a

File tree

3 files changed

+103
-40
lines changed

3 files changed

+103
-40
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
- Don't pass seed to openai API calls (only as extra body)
88
- Default to alignment threshold = 0 (better recall) for LLM annotated markup alignment with the original text
9+
- Fix `eds.llm_markup_extractor` context splitting to yield full docs and not parts of docs
910

1011
## v0.19.0 (2025-10-04)
1112

edsnlp/pipes/llm/llm_markup_extractor/llm_markup_extractor.py

Lines changed: 56 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import warnings
3+
from collections import deque
34
from typing import (
45
Any,
56
Callable,
@@ -416,8 +417,8 @@ def process(self, doc):
416417
def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
417418
"""
418419
Extract entities concurrently, but yield results in the same order
419-
as the input `docs`. Up to `max_concurrent_requests` documents are
420-
processed in parallel.
420+
as the input `docs`. Up to `max_concurrent_requests` span-level
421+
requests are processed in parallel.
421422
422423
Parameters
423424
----------
@@ -430,48 +431,70 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
430431
Processed documents in the original input order.
431432
"""
432433
if self.max_concurrent_requests <= 1: # pragma: no cover
433-
for ctx in docs:
434-
yield self.process(ctx)
434+
for doc in docs:
435+
yield self.process(doc)
435436
return
436437

437438
worker = AsyncRequestWorker.instance()
438439

439-
# Documents sent to the worker, waiting for results
440+
# Documents that are currently being processed, keyed by their
441+
# index in the input stream.
440442
pending_docs: Dict[int, Doc] = {}
441-
# Documents already processed, waiting to be yielded in order
443+
# Number of remaining contexts to process for each document.
444+
remaining_ctx_counts: Dict[int, int] = {}
445+
# Fully processed documents waiting to be yielded in order.
442446
buffer: Dict[int, Doc] = {}
443447
next_to_yield = 0
444-
in_flight: Dict[int, int] = {}
445448

446-
ctx_iter = enumerate(
447-
ctx for doc in docs for ctx in get_spans(doc, self.context_getter)
448-
)
449+
# In-flight LLM requests: task_id -> (doc_index, context)
450+
in_flight: Dict[int, Tuple[int, Any]] = {}
449451

450-
for _ in range(self.max_concurrent_requests):
451-
try:
452-
i, ctx = next(ctx_iter)
453-
except StopIteration:
454-
break
455-
messages = self.build_prompt(ctx)
456-
task_id = worker.submit(self._llm_request_coro(messages))
457-
in_flight[task_id] = i
458-
pending_docs[i] = ctx
452+
docs_iter = enumerate(docs)
453+
ctx_queue: "deque[Tuple[int, Any]]" = deque()
454+
455+
def enqueue_new_docs() -> None:
456+
# Fill the context queue up to `max_concurrent_requests`
457+
nonlocal docs_iter
458+
while len(ctx_queue) < self.max_concurrent_requests:
459+
try:
460+
doc_idx, doc = next(docs_iter)
461+
except StopIteration:
462+
break
463+
464+
pending_docs[doc_idx] = doc
465+
contexts = list(get_spans(doc, self.context_getter))
466+
467+
if not contexts:
468+
remaining_ctx_counts[doc_idx] = 0
469+
buffer[doc_idx] = doc
470+
else:
471+
remaining_ctx_counts[doc_idx] = len(contexts)
472+
for ctx in contexts:
473+
ctx_queue.append((doc_idx, ctx))
474+
475+
def submit_until_full() -> None:
476+
while len(in_flight) < self.max_concurrent_requests and ctx_queue:
477+
doc_idx, ctx = ctx_queue.popleft()
478+
messages = self.build_prompt(ctx)
479+
task_id = worker.submit(self._llm_request_coro(messages))
480+
in_flight[task_id] = (doc_idx, ctx)
481+
482+
enqueue_new_docs()
483+
submit_until_full()
459484

460485
while in_flight:
461486
done_task_id = worker.wait_for_any(in_flight.keys())
462487
result = worker.pop_result(done_task_id)
463-
i = in_flight.pop(done_task_id)
464-
ctx = pending_docs.pop(i)
488+
doc_idx, ctx = in_flight.pop(done_task_id)
465489

466490
if result is None:
467-
buffer[i] = ctx
491+
pass
468492
else:
469493
res, err = result
470494
if err is not None:
471495
self._handle_err(
472-
f"[llm_markup_extractor] failed for document #{i}: {err!r}"
496+
f"[llm_markup_extractor] failed for doc #{doc_idx}: {err!r}"
473497
)
474-
buffer[i] = ctx
475498
else:
476499
try:
477500
self.apply_markup_to_doc_(ctx, str(res))
@@ -480,23 +503,16 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
480503

481504
traceback.print_exc()
482505
self._handle_err(
483-
f"[llm_markup_extractor] "
484-
f"failed to parse result for document #{i}: {e!r} in "
485-
f"{res!r}"
506+
f"[llm_markup_extractor] failed to parse result for doc "
507+
f"#{doc_idx}: {e!r} in {res!r}"
486508
)
487-
buffer[i] = ctx
488509

489-
while True:
490-
try:
491-
if len(in_flight) >= self.max_concurrent_requests:
492-
break
493-
i2, d2 = next(ctx_iter)
494-
except StopIteration:
495-
break
496-
messages2 = self.build_prompt(d2)
497-
task_id2 = worker.submit(self._llm_request_coro(messages2))
498-
in_flight[task_id2] = i2
499-
pending_docs[i2] = d2
510+
remaining_ctx_counts[doc_idx] -= 1
511+
if remaining_ctx_counts[doc_idx] == 0:
512+
buffer[doc_idx] = pending_docs.pop(doc_idx)
513+
514+
enqueue_new_docs()
515+
submit_until_full()
500516

501517
while next_to_yield in buffer:
502518
yield buffer.pop(next_to_yield)
@@ -512,7 +528,7 @@ def _llm_request_sync(self, messages) -> str:
512528
messages=messages,
513529
**self.api_kwargs,
514530
)
515-
return response.choices[0].message.content
531+
return str(response.choices[0].message.content)
516532

517533
def _llm_request_coro(self, messages) -> Coroutine[Any, Any, str]:
518534
async def _coro():

tests/pipelines/llm/test_llm_markup_extractor.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,49 @@ def responder(**kw):
145145
docs = docs.to_iterable(converter="markup", preset="md")
146146
docs = list(docs)
147147
assert docs == md
148+
149+
150+
def test_context_getter_async():
151+
nlp = edsnlp.blank("eds")
152+
nlp.add_pipe("eds.normalizer")
153+
nlp.add_pipe("eds.sentences")
154+
nlp.add_pipe(
155+
eds.llm_markup_extractor(
156+
api_url="http://localhost:8080/v1",
157+
model="my-custom-model",
158+
prompt=PROMPT,
159+
max_concurrent_requests=2,
160+
context_getter="sents",
161+
)
162+
)
163+
164+
md = [
165+
"La patient souffre de [tuberculose](diagnosis). On débute une "
166+
"[antibiothérapie](treatment) dès ajd.",
167+
"Il a une [pneumonie](diagnosis) du thorax. C'est très grave.",
168+
]
169+
170+
counter = 0
171+
172+
def responder(messages, **kw):
173+
nonlocal counter
174+
counter += 1
175+
assert len(messages) == 2 # 1 system + 1 user
176+
res = (
177+
messages[-1]["content"]
178+
.replace("tuberculose", "<diagnosis>tuberculose</diagnosis>")
179+
.replace("antibiothérapie", "<treatment>antibiothérapie</treatment>")
180+
.replace("pneumonie", "<diagnosis>pneumonie</diagnosis>")
181+
.replace("grave", "grave</diagnosis>")
182+
)
183+
return res
184+
185+
with mock_llm_service(responder=responder):
186+
docs = edsnlp.data.from_iterable(md, converter="markup", preset="md")
187+
docs = docs.map(lambda x: x.text)
188+
docs = docs.map_pipeline(nlp)
189+
docs = docs.to_iterable(converter="markup", preset="md")
190+
docs = list(docs)
191+
assert docs == md
192+
193+
assert counter == 4

0 commit comments

Comments
 (0)