11import os
22import warnings
3+ from collections import deque
34from typing import (
45 Any ,
56 Callable ,
@@ -416,8 +417,8 @@ def process(self, doc):
416417 def pipe (self , docs : Iterable [Doc ]) -> Iterable [Doc ]:
417418 """
418419 Extract entities concurrently, but yield results in the same order
419- as the input `docs`. Up to `max_concurrent_requests` documents are
420- processed in parallel.
420+ as the input `docs`. Up to `max_concurrent_requests` span-level
421+ requests are processed in parallel.
421422
422423 Parameters
423424 ----------
@@ -430,48 +431,70 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
430431 Processed documents in the original input order.
431432 """
432433 if self .max_concurrent_requests <= 1 : # pragma: no cover
433- for ctx in docs :
434- yield self .process (ctx )
434+ for doc in docs :
435+ yield self .process (doc )
435436 return
436437
437438 worker = AsyncRequestWorker .instance ()
438439
439- # Documents sent to the worker, waiting for results
440+ # Documents that are currently being processed, keyed by their
441+ # index in the input stream.
440442 pending_docs : Dict [int , Doc ] = {}
441- # Documents already processed, waiting to be yielded in order
443+ # Number of remaining contexts to process for each document.
444+ remaining_ctx_counts : Dict [int , int ] = {}
445+ # Fully processed documents waiting to be yielded in order.
442446 buffer : Dict [int , Doc ] = {}
443447 next_to_yield = 0
444- in_flight : Dict [int , int ] = {}
445448
446- ctx_iter = enumerate (
447- ctx for doc in docs for ctx in get_spans (doc , self .context_getter )
448- )
449+ # In-flight LLM requests: task_id -> (doc_index, context)
450+ in_flight : Dict [int , Tuple [int , Any ]] = {}
449451
450- for _ in range (self .max_concurrent_requests ):
451- try :
452- i , ctx = next (ctx_iter )
453- except StopIteration :
454- break
455- messages = self .build_prompt (ctx )
456- task_id = worker .submit (self ._llm_request_coro (messages ))
457- in_flight [task_id ] = i
458- pending_docs [i ] = ctx
452+ docs_iter = enumerate (docs )
453+ ctx_queue : "deque[Tuple[int, Any]]" = deque ()
454+
455+ def enqueue_new_docs () -> None :
456+ # Fill the context queue up to `max_concurrent_requests`
457+ nonlocal docs_iter
458+ while len (ctx_queue ) < self .max_concurrent_requests :
459+ try :
460+ doc_idx , doc = next (docs_iter )
461+ except StopIteration :
462+ break
463+
464+ pending_docs [doc_idx ] = doc
465+ contexts = list (get_spans (doc , self .context_getter ))
466+
467+ if not contexts :
468+ remaining_ctx_counts [doc_idx ] = 0
469+ buffer [doc_idx ] = doc
470+ else :
471+ remaining_ctx_counts [doc_idx ] = len (contexts )
472+ for ctx in contexts :
473+ ctx_queue .append ((doc_idx , ctx ))
474+
475+ def submit_until_full () -> None :
476+ while len (in_flight ) < self .max_concurrent_requests and ctx_queue :
477+ doc_idx , ctx = ctx_queue .popleft ()
478+ messages = self .build_prompt (ctx )
479+ task_id = worker .submit (self ._llm_request_coro (messages ))
480+ in_flight [task_id ] = (doc_idx , ctx )
481+
482+ enqueue_new_docs ()
483+ submit_until_full ()
459484
460485 while in_flight :
461486 done_task_id = worker .wait_for_any (in_flight .keys ())
462487 result = worker .pop_result (done_task_id )
463- i = in_flight .pop (done_task_id )
464- ctx = pending_docs .pop (i )
488+ doc_idx , ctx = in_flight .pop (done_task_id )
465489
466490 if result is None :
467- buffer [ i ] = ctx
491+ pass
468492 else :
469493 res , err = result
470494 if err is not None :
471495 self ._handle_err (
472- f"[llm_markup_extractor] failed for document #{ i } : { err !r} "
496+ f"[llm_markup_extractor] failed for doc #{ doc_idx } : { err !r} "
473497 )
474- buffer [i ] = ctx
475498 else :
476499 try :
477500 self .apply_markup_to_doc_ (ctx , str (res ))
@@ -480,23 +503,16 @@ def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]:
480503
481504 traceback .print_exc ()
482505 self ._handle_err (
483- f"[llm_markup_extractor] "
484- f"failed to parse result for document #{ i } : { e !r} in "
485- f"{ res !r} "
506+ f"[llm_markup_extractor] failed to parse result for doc "
507+ f"#{ doc_idx } : { e !r} in { res !r} "
486508 )
487- buffer [i ] = ctx
488509
489- while True :
490- try :
491- if len (in_flight ) >= self .max_concurrent_requests :
492- break
493- i2 , d2 = next (ctx_iter )
494- except StopIteration :
495- break
496- messages2 = self .build_prompt (d2 )
497- task_id2 = worker .submit (self ._llm_request_coro (messages2 ))
498- in_flight [task_id2 ] = i2
499- pending_docs [i2 ] = d2
510+ remaining_ctx_counts [doc_idx ] -= 1
511+ if remaining_ctx_counts [doc_idx ] == 0 :
512+ buffer [doc_idx ] = pending_docs .pop (doc_idx )
513+
514+ enqueue_new_docs ()
515+ submit_until_full ()
500516
501517 while next_to_yield in buffer :
502518 yield buffer .pop (next_to_yield )
@@ -512,7 +528,7 @@ def _llm_request_sync(self, messages) -> str:
512528 messages = messages ,
513529 ** self .api_kwargs ,
514530 )
515- return response .choices [0 ].message .content
531+ return str ( response .choices [0 ].message .content )
516532
517533 def _llm_request_coro (self , messages ) -> Coroutine [Any , Any , str ]:
518534 async def _coro ():
0 commit comments