diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index f9ffa22ead..0078f29b9c 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -1,7 +1,7 @@ First, when starting to develop, install the project with ```bash -pip install -e ".[dev]" +pip install -e ".[dev]" "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" pre-commit install ``` diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index d182c0ae0c..de9ff57168 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -31,7 +31,7 @@ jobs: - name: Install dependencies run: | - pip install '.[docs]' + pip install '.[docs]' "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" # uv venv # uv pip install '.[docs]' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 69f7f9a3d0..26b9ec45c6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -33,7 +33,7 @@ jobs: uses: pypa/cibuildwheel@v2.21.3 env: CIBW_ARCHS_MACOS: "x86_64 arm64" - CIBW_ENVIRONMENT: PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu + CIBW_ENVIRONMENT: PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu PIP_ONLY_BINARY=pyarrow - uses: actions/upload-artifact@v4 with: @@ -88,7 +88,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install '.[docs]' + pip install '.[docs]' "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" - name: Set up Git run: | diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 5698496696..baaf8092bf 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -28,7 +28,7 @@ jobs: uses: pypa/cibuildwheel@v2.16.5 env: CIBW_ARCHS_MACOS: "x86_64 arm64" - CIBW_ENVIRONMENT: PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu + CIBW_ENVIRONMENT: PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu PIP_ONLY_BINARY=pyarrow build_sdist: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ff6662d46c..45117e1e5e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -61,22 +61,22 @@ jobs: cache: 'pip' - name: Install dependencies - run: pip install -e ".[dev]" + run: pip install -e ".[dev]" "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" if: matrix.python-version != '3.9' && matrix.python-version != '3.10' && matrix.python-version != '3.11' && matrix.python-version != '3.12' - name: Install dependencies - run: pip install -e ".[dev,setup]" + run: pip install -e ".[dev,setup,llm]" "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" if: matrix.python-version == '3.9' - name: Install dependencies # skip ML tests for 3.10 and 3.11 - run: pip install -e ".[dev-no-ml]" + run: pip install -e ".[dev-no-ml]" "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" if: matrix.python-version == '3.10' || matrix.python-version == '3.11' || matrix.python-version == '3.12' - name: Test with Pytest on Python ${{ matrix.python-version }} env: UMLS_API_KEY: ${{ secrets.UMLS_API_KEY }} - run: coverage run -m pytest --ignore tests/test_docs.py + run: coverage run -m pytest --ignore tests/test_docs.py --ignore tests/pipelines/qualifiers/test_llm_qualifier.py --ignore tests/pipelines/qualifiers/test_llm_utils.py if: matrix.python-version != '3.9' - name: Test with Pytest on Python ${{ matrix.python-version }} @@ -118,7 +118,7 @@ jobs: cache: 'pip' - name: Install dependencies - run: pip install -e ".[docs]" + run: pip install -e ".[docs]" "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" - name: Set up Git run: | @@ -197,7 +197,14 @@ jobs: - run: echo WEEK=$(date +%V) >>$GITHUB_ENV shell: bash + - name: Install library + run: | + pip install ".[ml,llm]" pytest + pytest tests/pipelines/test_pipelines.py + if: matrix.python-version != '3.7' + - name: Install library run: | pip install ".[ml]" pytest pytest tests/pipelines/test_pipelines.py + if: matrix.python-version == '3.7' diff --git a/README.md b/README.md index e27a7838dc..a09d870a5b 100644 --- a/README.md +++ b/README.md @@ -34,13 +34,13 @@ Check out our interactive [demo](https://aphp.github.io/edsnlp/demo/) ! You can install EDS-NLP via `pip`. We recommend pinning the library version in your projects, or use a strict package manager like [Poetry](https://python-poetry.org/). ```shell -pip install edsnlp==0.17.2 +pip install edsnlp==0.18.0 ``` or if you want to use the trainable components (using pytorch) ```shell -pip install "edsnlp[ml]==0.17.2" +pip install "edsnlp[ml]==0.18.0" ``` ### A first pipeline diff --git a/changelog.md b/changelog.md index ee6b315123..7ae94d4206 100644 --- a/changelog.md +++ b/changelog.md @@ -1,8 +1,10 @@ # Changelog -## Unreleased +## v0.18.0 (2025-09-02) -## Added +📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 support in the next major release (v0.19.0), in October 2025. Please upgrade to Python 3.10 or later. + +### Added - Added support for multiple loggers (`tensorboard`, `wandb`, `comet_ml`, `aim`, `mlflow`, `clearml`, `dvclive`, `csv`, `json`, `rich`) in `edsnlp.train` via the `logger` parameter. Default is [`json` and `rich`] for backward compatibility. - Sub batch sizes for gradient accumulation can now be defined as simple "splits" of the original batch, e.g. `batch_size = 10000 tokens` and `sub_batch_size = 5 splits` to accumulate batches of 2000 tokens. @@ -12,7 +14,7 @@ - New `Training a span classifier` tutorial, and reorganized deep-learning docs - `ScheduledOptimizer` now warns when a parameter selector does not match any parameter. -## Fixed +### Fixed - `use_section` in `eds.history` should now correctly handle cases when there are other sections following history sections. - Added clickable snippets in the documentation for more registered functions @@ -22,7 +24,7 @@ - :ambulance: Until now, `post_init` was applied **after** the instantiation of the optimizer : if the model discovered new labels, and therefore changed its parameter tensors to reflect that, these new tensors were not taken into account by the optimizer, which could likely lead to subpar performance. Now, `post_init` is applied **before** the optimizer is instantiated, so that the optimizer can correctly handle the new tensors. - Added missing entry points for readers and writers in the registry, including `write_parquet` and support for `polars` in `pyproject.toml`. Now all implemented readers and writers are correctly registered as entry points. -## Changed +### Changed - Sections cues in `eds.history` are now section titles, and not the full section. - :boom: Validation metrics are now found under the root field `validation` in the training logs (e.g. `metrics['validation']['ner']['micro']['f']`) diff --git a/contributing.md b/contributing.md index fcb6e46659..3730a68cc3 100644 --- a/contributing.md +++ b/contributing.md @@ -24,7 +24,7 @@ $ python -m venv venv $ source venv/bin/activate # Install the package with common, dev, setup dependencies in editable mode -$ pip install -e '.[dev,setup]' +$ pip install -e '.[dev,setup]' "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" # And build resources $ python scripts/conjugate_verbs.py ``` @@ -107,13 +107,13 @@ Most modern editors propose extensions that will format files on save. Make sure to document your improvements, both within the code with comprehensive docstrings, as well as in the documentation itself if need be. -We use `MkDocs` for EDS-NLP's documentation. You can checkout the changes you make with: +We use `MkDocs` for EDS-NLP's documentation. You can check out the changes you make with:
```console # Install the requirements -$ pip install -e '.[docs]' +$ pip install -e '.[docs]' "mkdocs-eds@git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'" ---> 100% color:green Installation successful diff --git a/docs/index.md b/docs/index.md index 3e6152a4ab..c1869d29bb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -15,13 +15,13 @@ Check out our interactive [demo](https://aphp.github.io/edsnlp/demo/) ! You can install EDS-NLP via `pip`. We recommend pinning the library version in your projects, or use a strict package manager like [Poetry](https://python-poetry.org/). ```{: data-md-color-scheme="slate" } -pip install edsnlp==0.17.2 +pip install edsnlp==0.18.0 ``` or if you want to use the trainable components (using pytorch) ```{: data-md-color-scheme="slate" } -pip install "edsnlp[ml]==0.17.2" +pip install "edsnlp[ml]==0.18.0" ``` ### A first pipeline diff --git a/docs/pipes/qualifiers/llm-qualifier.md b/docs/pipes/qualifiers/llm-qualifier.md new file mode 100644 index 0000000000..ba85519f4f --- /dev/null +++ b/docs/pipes/qualifiers/llm-qualifier.md @@ -0,0 +1,26 @@ +## LLM Span Classifier {: #edsnlp.pipes.qualifiers.llm.factory.create_component } + +::: edsnlp.pipes.qualifiers.llm.factory.create_component + options: + heading_level: 3 + show_bases: false + show_source: false + only_class_level: true + +## APIParams {: #edsnlp.pipes.qualifiers.llm.llm_qualifier.APIParams } + +::: edsnlp.pipes.qualifiers.llm.llm_qualifier.APIParams + options: + heading_level: 3 + show_bases: false + show_source: false + only_class_level: true + +## PromptConfig {: #edsnlp.pipes.qualifiers.llm.llm_qualifier.PromptConfig } + +::: edsnlp.pipes.qualifiers.llm.llm_qualifier.PromptConfig + options: + heading_level: 3 + show_bases: false + show_source: false + only_class_level: true diff --git a/docs/tutorials/qualifying-entities-with-llm.md b/docs/tutorials/qualifying-entities-with-llm.md new file mode 100644 index 0000000000..7465a42368 --- /dev/null +++ b/docs/tutorials/qualifying-entities-with-llm.md @@ -0,0 +1,229 @@ +# Using a LLM as a span qualifier +In this tutorial we woud learn how to use the `LLMSpanClassifier` pipe to qualify spans. +You should install the extra dependencies before in a python environment (python>='3.8'): +```bash +pip install edsnlp[llm] +``` + +## Using a local LLM server +We suppose that there is an available LLM server compatible with OpenAI API. +For example, using the library vllm you can launch an LLM server as follows in command line: +```bash +vllm serve Qwen/Qwen3-8B --port 8000 --enable-prefix-caching --tensor-parallel-size 1 --max-num-seqs=10 --max-num-batched-tokens=35000 +``` + +## Using an external API +You can also use the [Openai API](https://openai.com/index/openai-api/) or the [Groq API](https://groq.com/). + +!!! warning + + As you are probably working with sensitive medical data, please check whether you can use an external API or if you need to expose an API in your own infrastructure. + +## Import dependencies +```{ .python .no-check } +from datetime import datetime + +import pandas as pd + +import edsnlp +import edsnlp.pipes as eds +from edsnlp.pipes.qualifiers.llm.llm_qualifier import LLMSpanClassifier +from edsnlp.utils.span_getters import make_span_context_getter +``` +## Define prompt and examples +```{ .python .no-check } +task_prompts = { + 0: { + "normalized_task_name": "biopsy_procedure", + "system_prompt": "You are a medical assistant and you will help answering questions about dates present in clinical notes. Don't answer reasoning. " + + "We are interested in detecting biopsy dates (either procedure, analysis or result). " + + "You should answer in a JSON object following this schema {'biopsy':bool}. " + + "If there is not enough information, answer {'biopsy':'False'}." + + "\n\n#### Examples:\n", + "examples": [ + ( + "07/12/2020", + "07/12/2020 : Anapath / biopsies rectales : Muqueuse rectale normale sous réserve de fragments de petite taille.", + "{'biopsy':'True'}", + ), + ( + "24/12/2021", + "Chirurgie 24/12/2021 : Colectomie gauche + anastomose colo rectale + clearance hépatique gauche (une méta posée sur", + "{'biopsy':'False'}", + ), + ], + "prefix_prompt": "\nDetermine if '{span}' corresponds to a biopsy date. The text is as follows:\n<<< ", + "suffix_prompt": " >>>", + "json_schema": { + "properties": { + "biopsy": {"title": "Biopsy", "type": "boolean"}, + }, + "required": [ + "biopsy", + ], + "title": "DateModel", + "type": "object", + }, + "response_mapping": { + "(?i)(oui)|(yes)|(true)": "1", + "(?i)(non)|(no)|(false)|(don't)|(not)": "0", + }, + }, +} +``` + +## Format these examples for few-shot learning +```{ .python .no-check } +def format_examples(raw_examples, prefix_prompt, suffix_prompt): + examples = [] + + for date, context, answer in raw_examples: + prompt = prefix_prompt.format(span=date) + context + suffix_prompt + examples.append((prompt, answer)) + + return examples +``` + +## Set parameters and prompts +```{ .python .no-check } +# Set prompt +prompt_id = 0 +raw_examples = task_prompts.get(prompt_id).get("examples") +prefix_prompt = task_prompts.get(prompt_id).get("prefix_prompt") +user_prompt = task_prompts.get(prompt_id).get("user_prompt") +system_prompt = task_prompts.get(prompt_id).get("system_prompt") +suffix_prompt = task_prompts.get(prompt_id).get("suffix_prompt") +examples = format_examples(raw_examples, prefix_prompt, suffix_prompt) + +# Define JSON schema +response_format = { + "type": "json_schema", + "json_schema": { + "name": "DateModel", + # "strict": True, + "schema": task_prompts.get(prompt_id)["json_schema"], + }, +} + +# Set parameters +response_mapping = None +max_tokens = 200 +extra_body = { + # "chat_template_kwargs": {"enable_thinking": False}, +} +temperature = 0 +``` + +=== "For local serving" + + ```{ .python .no-check } + ### For local serving + model_name = "Qwen/Qwen3-8B" + api_url = "http://localhost:8000/v1" + api_key = "EMPTY_API_KEY" + ``` + + +=== "Using the Groq API" + !!! warning + ⚠️ This section involves the use of an external API. Please ensure you have the necessary credentials and understand the potential risks associated with external API usage. + + ```{ .python .no-check } + ### Using Groq API + model_name = "openai/gpt-oss-20b" + api_url = "https://api.groq.com/openai/v1" + api_key = "TOKEN" ## your API KEY + ``` + +## Define the pipeline +```{ .python .no-check } +nlp = edsnlp.blank("eds") +nlp.add_pipe("sentencizer") +nlp.add_pipe(eds.dates()) +nlp.add_pipe( + LLMSpanClassifier( + name="llm", + model=model_name, + span_getter=["dates"], + attributes={"_.biopsy_procedure": True}, + context_getter=make_span_context_getter( + context_sents=(3, 3), + context_words=(1, 1), + ), + prompt=dict( + system_prompt=system_prompt, + user_prompt=user_prompt, + prefix_prompt=prefix_prompt, + suffix_prompt=suffix_prompt, + examples=examples, + ), + api_params=dict( + max_tokens=max_tokens, + temperature=temperature, + response_format=response_format, + extra_body=extra_body, + ), + api_url=api_url, + api_key=api_key, + response_mapping=response_mapping, + n_concurrent_tasks=4, + ) +) +``` + +## Apply it on a document + +```{ .python .no-check } +# Let's try with a fake LLM generated text +text = """ +Centre Hospitalier Départemental – RCP Prostate – 20/02/2025 + +M. Bernard P., 69 ans, retraité, consulte après avoir noté une faiblesse du jet urinaire et des levers nocturnes répétés depuis un an. PSA à 15,2 ng/mL (05/02/2025). TR : nodule ferme sur lobe gauche. + +IRM multiparamétrique du 10/02/2025 : lésion PIRADS 5, 2,1 cm, atteinte de la capsule suspectée. +Biopsies du 12/02/2025 : adénocarcinome Gleason 4+4=8, toutes les carottes gauches positives. +Scanner TAP et scintigraphie osseuse du 14/02 : absence de métastases viscérales ou osseuses. + +En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. Décision : radiothérapie externe + hormonothérapie longue (24 mois). Planification de la simulation scanner le 25/02. +""" +``` + +```{ .python .no-check } +t0 = datetime.now() +doc = nlp(text) +t1 = datetime.now() +print("Execution time", t1 - t0) + +for span in doc.spans["dates"]: + print(span, span._.biopsy_procedure) +``` + +Lets check the type +```{ .python .no-check } +type(span._.biopsy_procedure) +``` +# Apply on multiple documents +```{ .python .no-check } +texts = [ + text, +] * 2 + +notes = pd.DataFrame({"note_id": range(len(texts)), "note_text": texts}) +docs = edsnlp.data.from_pandas(notes, nlp=nlp, converter="omop") +predicted_docs = docs.map_pipeline(nlp, 2) +``` + +```{ .python .no-check } +t0 = datetime.now() +note_nlp = edsnlp.data.to_pandas( + predicted_docs, + converter="ents", + span_getter="dates", + span_attributes=[ + "biopsy_procedure", + ], +) +t1 = datetime.now() +print("Execution time", t1 - t0) +note_nlp.head() +``` diff --git a/edsnlp/__init__.py b/edsnlp/__init__.py index 31cce4bd8a..372d40bd68 100644 --- a/edsnlp/__init__.py +++ b/edsnlp/__init__.py @@ -15,7 +15,7 @@ import edsnlp.pipes from . import reducers -__version__ = "0.17.2" +__version__ = "0.18.0" BASE_DIR = Path(__file__).parent diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index aea3f0f088..91d48bc6a5 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -74,6 +74,7 @@ from .qualifiers.negation.factory import create_component as negation from .qualifiers.reported_speech.factory import create_component as reported_speech from .qualifiers.reported_speech.factory import create_component as rspeech + from .qualifiers.llm.factory import create_component as llm_span_qualifier from .trainable.ner_crf.factory import create_component as ner_crf from .trainable.biaffine_dep_parser.factory import create_component as biaffine_dep_parser from .trainable.extractive_qa.factory import create_component as extractive_qa diff --git a/edsnlp/pipes/qualifiers/llm/__init__.py b/edsnlp/pipes/qualifiers/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/edsnlp/pipes/qualifiers/llm/factory.py b/edsnlp/pipes/qualifiers/llm/factory.py new file mode 100644 index 0000000000..c65132595f --- /dev/null +++ b/edsnlp/pipes/qualifiers/llm/factory.py @@ -0,0 +1,7 @@ +from edsnlp.core import registry + +from .llm_qualifier import LLMSpanClassifier + +create_component = registry.factory.register( + "eds.llm_span_qualifier", +)(LLMSpanClassifier) diff --git a/edsnlp/pipes/qualifiers/llm/llm_qualifier.py b/edsnlp/pipes/qualifiers/llm/llm_qualifier.py new file mode 100644 index 0000000000..4683ce3b6e --- /dev/null +++ b/edsnlp/pipes/qualifiers/llm/llm_qualifier.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import logging +import re +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) + +from spacy.tokens import Doc, Span +from typing_extensions import TypedDict + +from edsnlp.core.pipeline import Pipeline +from edsnlp.pipes.base import BaseSpanAttributeClassifierComponent +from edsnlp.pipes.qualifiers.llm.llm_utils import ( + AsyncLLM, + create_prompt_messages, +) +from edsnlp.utils.asynchronous import run_async +from edsnlp.utils.bindings import ( + BINDING_SETTERS, + Attributes, + AttributesArg, +) +from edsnlp.utils.span_getters import SpanGetterArg, get_spans + +logger = logging.getLogger(__name__) + +LLMSpanClassifierBatchInput = TypedDict( + "LLMSpanClassifierBatchInput", + { + "queries": List[str], + }, +) +""" +queries: List[str] + List of queries to send to the LLM for classification. + Each query corresponds to a span and its context. +""" + +LLMSpanClassifierBatchOutput = TypedDict( + "LLMSpanClassifierBatchOutput", + { + "labels": Optional[Union[List[str], List[List[str]]]], + }, +) +""" +labels: Optional[Union[List[str], List[List[str]]]] + The predicted labels for each query. + If `n > 1`, this will be a list of lists, where each inner list contains the + predictions for a single query. + If `n == 1`, this will be a list of strings, where each string is the prediction + for a single query. + + If the API call fails or no predictions are made, this will be None. + If `n > 1`, it will be a list of None values for each query. + If `n == 1`, it will be a single None value. + +""" + + +class PromptConfig(TypedDict, total=False): + """ + Parameters + ---------- + system_prompt : Optional[str] + A system prompt to use for the LLM. This is a general prompt that will be + prepended to each query. This prompt will be passed under the `system` role + in the OpenAI API call. + Example: "You are a medical expert. Classify the following text." + If None, no system prompt will be used. + Note: This is not the same as the `user_prompt` parameter. + user_prompt : Optional[str] + A general prompt to use for all spans. This is a prompt that will be prepended + to each span's specific prompt. This will be passed under the `user` role + in the OpenAI API call. + prefix_prompt : Optional[str] + A prefix prompt to paste after the `user_prompt` and before the selected context + of the span (using the `context_getter`). + It will be formatted specifically for each span, using the `span` variable. + Example: "Is '{span}' a Colonoscopy (procedure) date?" + suffix_prompt: Optional[str] + A suffix prompt to append at the end of the prompt. + examples : Optional[List[Tuple[str, str]]] + A list of examples to use for the prompt. Each example is a tuple of + (input, output). The input is the text to classify and the output is the + expected classification. + If None, no examples will be used. + Example: [("This is a colonoscopy date.", "colonoscopy_date")] + """ + + system_prompt: Optional[str] + user_prompt: Optional[str] + prefix_prompt: Optional[str] + suffix_prompt: Optional[str] + examples: Optional[List[Tuple[str, str]]] + + +class APIParams(TypedDict, total=False): + """ + Parameters + ---------- + extra_body : Optional[Dict[str, Any]] + Additional body parameters to pass to the vLLM API. + This can be used to pass additional parameters to the model, such as + `reasoning_parser` or `enable_reasoning`. + response_format : Optional[Dict[str, Any]] + The response format to use for the vLLM API call. + This can be used to specify how the response should be formatted. + temperature : float + The temperature for the vLLM API call. Default is 0.0 (deterministic). + max_tokens : int + The maximum number of tokens to generate in the response. + Default is 50. + """ + + max_tokens: int + temperature: float + response_format: Optional[Dict[str, Any]] + extra_body: Optional[Dict[str, Any]] + + +class LLMSpanClassifier( + BaseSpanAttributeClassifierComponent, +): + """ + The `LLMSpanClassifier` component is a LLM attribute predictor. + In this context, the span classification task consists in assigning values (boolean, + strings or any object) to attributes/extensions of spans such as: + + - `span._.negation`, + - `span._.date.mode` + - `span._.cui` + + This pipe will use an LLM API to classify previously identified spans using + the context and instructions around each span. + + Check out the LLM classifier tutorial + for examples ! + + Python >= 3.8 is required. + + Parameters + ---------- + nlp : PipelineProtocol + The pipeline object + name : str + Name of the component + prompt : Optional[PromptConfig] + The prompt configuration to use for the LLM. + api_url : str + The base URL of the vLLM OpenAI-compatible server to call. + Default: "http://localhost:8000/v1" + model : str + The name of the model to use for classification. + Default: "Qwen/Qwen3-8B" + span_getter : SpanGetterArg + How to extract the candidate spans and the attributes to predict or train on. + context_getter : Optional[Union[Callable, SpanGetterArg]] + What context to use when computing the span embeddings (defaults to the whole + document). This can be: + + - a `SpanGetterArg` to retrieve contexts from a whole document. For example + `{"section": "conclusion"}` to only use the conclusion as context (you + must ensure that all spans produced by the `span_getter` argument do fall + in the conclusion in this case) + - a callable, that gets a span and should return a context for this span. + For instance, `lambda span: span.sent` to use the sentence as context. + attributes : AttributesArg + The attributes to predict or train on. If a dict is given, keys are the + attributes and values are the labels for which the attr is allowed, or True + if the attr is allowed for all labels. + api_params : APIParams + Additional parameters for the vLLM API call. + response_mapping : Optional[Dict[str, Any]] + A mapping from regex patterns to values that will be used to map the + responses from the model to the bindings. If not provided, the raw + responses will be used. The first matching regex will be used to map the + response to the binding. + Example: `{"^yes$": True, "^no$": False}` will map "yes" to True and "no" to + False. + timeout : float + The timeout for the vLLM API call. Default is 15.0 seconds. + n_concurrent_tasks : int + The number of concurrent tasks to run when calling the vLLM API. + Default is 4. + kwargs: Dict[str, Any] + Additional keyword arguments passed to the vLLM API call. + This can include parameters like `n` for the number of responses to generate, + or any other OpenAI API parameters. + + Authors and citation + -------------------- + The `eds.llm_qualifier` component was developed by AP-HP's Data Science team. + """ + + def __init__( + self, + nlp: Optional[Pipeline] = None, + name: str = "span_classifier", + prompt: Optional[PromptConfig] = None, + api_url: str = "http://localhost:8000/v1", + model: str = "Qwen/Qwen3-8B", + *, + attributes: AttributesArg = None, + span_getter: SpanGetterArg = None, + context_getter: Optional[SpanGetterArg] = None, + response_mapping: Optional[Dict[str, Any]] = None, + api_params: APIParams = { + "max_tokens": 50, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + timeout: float = 15.0, + n_concurrent_tasks: int = 4, + **kwargs, + ): + if attributes is None: + raise TypeError( + "The `attributes` parameter is required. Please provide a dict of " + "attributes to predict or train on." + ) + + span_getter = span_getter or {"ents": True} + + self.bindings: List[Tuple[str, List[str], List[Any]]] = [ + (k if k.startswith("_.") else f"_.{k}", v, []) + for k, v in attributes.items() + ] + + # Store API configuration + self.api_url = api_url + self.model = model + self.extra_body = api_params.get("extra_body") + self.temperature = api_params.get("temperature") + self.max_tokens = api_params.get("max_tokens") + self.response_format = api_params.get("response_format") + self.response_mapping = response_mapping + self.kwargs = kwargs.get("kwargs") or {} + self.timeout = timeout + self.n_concurrent_tasks = n_concurrent_tasks + + # Prompt config + prompt = prompt or {} + self.prompt = prompt + self.system_prompt = prompt.get("system_prompt") + self.user_prompt = prompt.get("user_prompt") + self.prefix_prompt = prompt.get("prefix_prompt") + self.suffix_prompt = prompt.get("suffix_prompt") + self.examples = prompt.get("examples") + + super().__init__(nlp, name, span_getter=span_getter) + self.context_getter = context_getter + + if self.response_mapping: + self.get_response_mapping_regex_dict() + + @property + def attributes(self) -> Attributes: + return {qlf: labels for qlf, labels, _ in self.bindings} + + def set_extensions(self): + super().set_extensions() + for group in self.bindings: + qlf = group[0] + if qlf.startswith("_."): + qlf = qlf[2:] + if not Span.has_extension(qlf): + Span.set_extension(qlf, default=None) + + def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]: + spans = list(get_spans(doc, self.span_getter)) + spans_text = [span.text for span in spans] + if self.context_getter is None or not callable(self.context_getter): + contexts = list(get_spans(doc, self.context_getter)) + else: + contexts = [self.context_getter(span) for span in spans] + + contexts_text = [context.text for context in contexts] + + doc_batch_messages = [] + for span_text, context_text in zip(spans_text, contexts_text): + if self.prefix_prompt: + final_user_prompt = ( + self.prefix_prompt.format(span=span_text) + context_text + ) + else: + final_user_prompt = context_text + if self.suffix_prompt: + final_user_prompt += self.suffix_prompt + + messages = create_prompt_messages( + system_prompt=self.system_prompt, + user_prompt=self.user_prompt, + examples=self.examples, + final_user_prompt=final_user_prompt, + ) + doc_batch_messages.append(messages) + + return { + "$spans": spans, + "spans_text": spans_text, + "contexts": contexts, + "contexts_text": contexts_text, + "doc_batch_messages": doc_batch_messages, + } + + def collate(self, batch: Dict[str, Sequence[Any]]) -> LLMSpanClassifierBatchInput: + collated = { + "batch_messages": [ + message for item in batch for message in item["doc_batch_messages"] + ] + } + + return collated + + # noinspection SpellCheckingInspection + def forward( + self, + batch: LLMSpanClassifierBatchInput, + ) -> Dict[str, List[Any]]: + """ + Apply the span classifier module to the document embeddings and given spans to: + - compute the loss + - and/or predict the labels of spans + + Parameters + ---------- + batch: SpanClassifierBatchInput + The input batch + + Returns + ------- + BatchOutput + """ + + # Here call the LLM API + llm = AsyncLLM( + model_name=self.model, + api_url=self.api_url, + extra_body=self.extra_body, + temperature=self.temperature, + max_tokens=self.max_tokens, + response_format=self.response_format, + timeout=self.timeout, + n_concurrent_tasks=self.n_concurrent_tasks, + **self.kwargs, + ) + pred = run_async(llm(batch_messages=batch["batch_messages"])) + + return { + "labels": pred, + } + + def get_response_mapping_regex_dict(self) -> Dict[str, str]: + self.response_mapping_regex = { + re.compile(regex): mapping_value + for regex, mapping_value in self.response_mapping.items() + } + return self.response_mapping_regex + + def map_response(self, value: str) -> str: + for ( + compiled_regex, + mapping_value, + ) in self.response_mapping_regex.items(): + if compiled_regex.search(value): + mapped_value = mapping_value + break + else: + mapped_value = None + return mapped_value + + def postprocess( + self, + docs: Sequence[Doc], + results: LLMSpanClassifierBatchOutput, + inputs: List[Dict[str, Any]], + ) -> Sequence[Doc]: + # Preprocessed docs should still be in the cache + spans = [span for sample in inputs for span in sample["$spans"]] + all_labels = results["labels"] + # For each prediction group (exclusive bindings)... + + for qlf, labels, _ in self.bindings: + for value, span in zip(all_labels, spans): + if labels is True or span.label_ in labels: + if value is None: + mapped_value = None + elif self.response_mapping is not None: + # ...assign the mapped value to the span + mapped_value = self.map_response(value) + else: + mapped_value = value + BINDING_SETTERS[qlf](span, mapped_value) + + return docs + + def batch_process(self, docs): + inputs = [self.preprocess(doc) for doc in docs] + collated = self.collate(inputs) + res = self.forward(collated) + docs = self.postprocess(docs, res, inputs) + + return docs + + def enable_cache(self, cache_id=None): + # For compatibility + pass + + def disable_cache(self, cache_id=None): + # For compatibility + pass diff --git a/edsnlp/pipes/qualifiers/llm/llm_utils.py b/edsnlp/pipes/qualifiers/llm/llm_utils.py new file mode 100644 index 0000000000..2e834f99f3 --- /dev/null +++ b/edsnlp/pipes/qualifiers/llm/llm_utils.py @@ -0,0 +1,387 @@ +import asyncio +import json +import logging +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, +) + +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion + +logger = logging.getLogger(__name__) + + +class AsyncLLM: + """ + AsyncLLM is an asynchronous interface for interacting with Large Language Models + (LLMs) via an API, supporting concurrent requests and batch processing. + """ + + def __init__( + self, + model_name="Qwen/Qwen3-8B", + api_url: str = "http://localhost:8000/v1", + temperature: float = 0.0, + max_tokens: int = 50, + extra_body: Optional[Dict[str, Any]] = None, + response_format: Optional[Dict[str, Any]] = None, + api_key: str = "EMPTY_API_KEY", + n_completions: int = 1, + timeout: float = 15.0, + n_concurrent_tasks: int = 4, + **kwargs, + ): + """ + Initializes the AsyncLLM class with configuration parameters for interacting + with an OpenAI-compatible API server. + + Parameters + ---------- + model_name : str, optional + Name of the model to use (default: "Qwen/Qwen3-8B"). + api_url : str, optional + Base URL of the API server (default: "http://localhost:8000/v1"). + temperature : float, optional + Sampling temperature for generation (default: 0.0). + max_tokens : int, optional + Maximum number of tokens to generate per completion (default: 50). + extra_body : Optional[Dict[str, Any]], optional + Additional parameters to include in the API request body (default: None). + response_format : Optional[Dict[str, Any]], optional + Format specification for the API response (default: None). + api_key : str, optional + API key for authentication (default: "EMPTY_API_KEY"). + n_completions : int, optional + Number of completions to request per prompt (default: 1). + timeout : float, optional + Timeout for API requests in seconds (default: 15.0). + n_concurrent_tasks : int, optional + Maximum number of concurrent tasks for API requests (default: 4). + **kwargs + Additional keyword arguments for further customization. + """ + + # Set OpenAI's API key and API base to use vLLM's API server. + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + self.extra_body = extra_body + self.response_format = response_format + self.n_completions = n_completions + self.timeout = timeout + self.kwargs = kwargs + self.n_concurrent_tasks = n_concurrent_tasks + self.responses = [] + self._lock = None + + self.client = AsyncOpenAI( + api_key=api_key, + base_url=api_url, + default_headers={"Connection": "close"}, + ) + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit - properly close the client.""" + await self.aclose() + + async def aclose(self): + """Properly close the AsyncOpenAI client to prevent resource leaks.""" + if hasattr(self, "client") and self.client is not None: + await self.client.close() + + @property + def lock(self): + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock + + async def async_id_message_generator( + self, batch_messages: List[List[Dict[str, str]]] + ) -> AsyncIterator[Tuple[int, List[Dict[str, str]]]]: + """ + Generator + """ + for i, messages in enumerate(batch_messages): + yield (i, messages) + + def parse_messages( + self, response: ChatCompletion, response_format: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Parse the response from the LLM and return the content. + """ + if (response_format is not None) and (isinstance(response, ChatCompletion)): + prediction = [ + parse_json_response( + choice.message.content, response_format=response_format + ) + for choice in response.choices + ] + if self.n_completions == 1: + prediction = prediction[0] + + return prediction + + else: + return response + + async def call_llm( + self, id: int, messages: List[Dict[str, str]] + ) -> Tuple[int, ChatCompletion]: + """ + Call the LLM with the given messages and return the response. + + Parameters + ---------- + id : int + Unique identifier for the call + messages : List[Dict[str, str]] + List of messages to send to the LLM, where each message is a dictionary + with keys 'role' and 'content'. + + Returns + ------- + Tuple[int, ChatCompletion] + The id of the call and the ChatCompletion object corresponding to the + LLM response + """ + + raw_response = await asyncio.wait_for( + self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=self.max_tokens, + n=self.n_completions, + temperature=self.temperature, + stream=False, + response_format=self.response_format, + extra_body=self.extra_body, + **self.kwargs, + ), + timeout=self.timeout, + ) + + # Parse the response + parsed_response = self.parse_messages(raw_response, self.response_format) + + return id, parsed_response + + def store_responses(self, p_id, abbreviation_list): + """ """ + self.responses.append((p_id, abbreviation_list)) + + async def async_worker( + self, + name: str, + id_messages_tuples: AsyncIterator[Tuple[int, List[List[Dict[str, str]]]]], + ): + while True: + try: + ( + idx, + message, + ) = await anext(id_messages_tuples) # noqa: F821 + idx, response = await self.call_llm(idx, message) + + logger.info(f"Worker {name} has finished process {idx}") + except StopAsyncIteration: + # Everything has been parsed! + logger.info( + f"[{name}] Received StopAsyncIteration, worker will shutdown" + ) + break + except TimeoutError as e: + logger.error(f"[{name}] TimeoutError on chunk {idx}\n{e}") + logger.error(f"Timeout was set to {self.timeout} seconds") + if self.n_completions == 1: + response = "" + else: + response = [""] * self.n_completions + except BaseException as e: + logger.error( + f"[{name}] Exception raised on chunk {idx}\n{e}" + ) # type(e) + if self.n_completions == 1: + response = "" + else: + response = [""] * self.n_completions + async with self.lock: + self.store_responses( + idx, + response, + ) + + def sort_responses(self): + sorted_responses = [] + for i, output in sorted(self.responses, key=lambda x: x[0]): + if isinstance(output, ChatCompletion): + if self.n_completions == 1: + sorted_responses.append(output.choices[0].message.content) + else: + sorted_responses.append( + [choice.message.content for choice in output.choices] + ) + else: + sorted_responses.append(output) + + return sorted_responses + + def clean_storage(self): + del self.responses + self.responses = [] + + async def __call__(self, batch_messages: List[List[Dict[str, str]]]): + """ + Asynchronous coroutine, it should be called using the + `edsnlp.utils.asynchronous.run_async` function. + + Parameters + ---------- + batch_messages : List[List[Dict[str, str]]] + List of message batches to send to the LLM, where each batch is a list + of dictionaries with keys 'role' and 'content'. + """ + try: + # Shared prompt generator + id_messages_tuples = self.async_id_message_generator(batch_messages) + + # n concurrent tasks + tasks = { + asyncio.create_task( + self.async_worker(f"Worker-{i}", id_messages_tuples) + ) + for i in range(self.n_concurrent_tasks) + } + + await asyncio.gather(*tasks) + tasks.clear() + predictions = self.sort_responses() + self.clean_storage() + + return predictions + except Exception: + # Ensure cleanup even if an exception occurs + await self.aclose() + raise + + +def create_prompt_messages( + system_prompt: Optional[str] = None, + user_prompt: Optional[str] = None, + examples: Optional[List[Tuple[str, str]]] = None, + final_user_prompt: Optional[str] = None, +) -> List[Dict[str, str]]: + """ + Create a list of prompt messages formatted for use with a language model (LLM) API. + + system_prompt : Optional[str], default=None + The initial system prompt to set the behavior or context for the LLM. + user_prompt : Optional[str], default=None + An initial user prompt to provide context or instructions to the LLM. + examples : Optional[List[Tuple[str, str]]], default=None + A list of example (prompt, response) pairs to guide the LLM's behavior. + final_user_prompt : Optional[str], default=None + The final user prompt to be appended at the end of the message sequence. + + Returns + ------- + List[Dict[str, str]] + A list of message dictionaries, each containing a 'role' + (e.g., 'system', 'user', 'assistant') + and corresponding 'content', formatted for LLM input. + + """ + + messages = [] + if system_prompt: + messages.append( + { + "role": "system", + "content": system_prompt, + } + ) + if user_prompt: + messages.append( + { + "role": "user", + "content": user_prompt, + } + ) + if examples: + for prompt, response in examples: + messages.append( + { + "role": "user", + "content": prompt, + } + ) + messages.append( + { + "role": "assistant", + "content": response, + } + ) + if final_user_prompt: + messages.append( + { + "role": "user", + "content": final_user_prompt, + } + ) + + return messages + + +def parse_json_response( + response: str, + response_format: Optional[Dict[str, Any]] = None, + errors: str = "ignore", +) -> Dict[str, Any]: + """ + Parses a response string as JSON if a JSON schema format is specified, + otherwise returns the raw response. + + Parameters + ---------- + response : str + The response string to parse. + response_format : Optional[Dict[str, Any]], optional + A dictionary specifying the expected response format. + If it contains {"type": "json_schema"}, the response will be parsed as JSON. + Defaults to None. + errors : str, optional + Determines error handling behavior when JSON decoding fails. + If set to "ignore", returns an empty dictionary on failure. + Otherwise, returns the raw response. Defaults to "ignore". + + Returns + ------- + Dict[str, Any] + The parsed JSON object if parsing is successful and a JSON schema is specified. + If parsing fails and errors is "ignore", returns an empty dictionary. + If parsing fails and errors is not "ignore", returns the raw response string. + If no response format is specified, returns the raw response string. + """ + if response is None: + return {} + + if (response_format is not None) and (response_format.get("type") == "json_schema"): + try: + return json.loads(response.strip()) + except json.JSONDecodeError: + if errors == "ignore": + return {} + else: + return response + else: + # If no response format is specified, return the raw response + return response diff --git a/edsnlp/utils/asynchronous.py b/edsnlp/utils/asynchronous.py new file mode 100644 index 0000000000..58cadb39ec --- /dev/null +++ b/edsnlp/utils/asynchronous.py @@ -0,0 +1,37 @@ +import asyncio +from typing import Any, Coroutine, Optional, TypeVar + +T = TypeVar("T") + + +def run_async(coro: Coroutine[Any, Any, T]) -> T: + """ + Runs an asynchronous coroutine and always waits for the result, + whether or not an event loop is already running. + + In a standard Python script (no active event loop), it uses `asyncio.run()`. + In a notebook or environment with a running event loop, it applies a patch + using `nest_asyncio` and runs the coroutine via `loop.run_until_complete`. + + Parameters + ---------- + coro : Coroutine + The coroutine to run. + + Returns + ------- + T + The result returned by the coroutine. + """ + try: + loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): # pragma: no cover + import nest_asyncio + + nest_asyncio.apply() + return asyncio.get_running_loop().run_until_complete(coro) + else: + return asyncio.run(coro) diff --git a/mkdocs.yml b/mkdocs.yml index 1271f334f5..0802cbbd88 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -56,6 +56,7 @@ nav: - tutorials/training-ner.md - tutorials/training-span-classifier.md - tutorials/tuning.md + - tutorials/qualifying-entities-with-llm.md - Pipes: - Overview: pipes/index.md - Core Pipelines: @@ -73,6 +74,7 @@ nav: - pipes/qualifiers/hypothesis.md - pipes/qualifiers/reported-speech.md - pipes/qualifiers/history.md + - 'LLM Classifier': pipes/qualifiers/llm-qualifier.md - Miscellaneous: - pipes/misc/index.md - pipes/misc/dates.md diff --git a/pyproject.toml b/pyproject.toml index ba2a525beb..61c82390d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,6 @@ dev-no-ml = [ "edsnlp[docs-no-ml]", ] docs-no-ml = [ - "mkdocs-eds @ git+https://github.com/percevalw/mkdocs-eds.git@main#egg=mkdocs-eds ; python_version>='3.9'", "markdown-grid-tables==0.4.0; python_version>='3.9'", ] ml = [ @@ -76,9 +75,15 @@ ml = [ "transformers>=4.0.0", "accelerate>=0.20.3", ] +llm = [ + "openai>=1.90.0; python_version>='3.8'", + "nest_asyncio", + "respx", +] docs = [ "edsnlp[docs-no-ml]", "edsnlp[ml]", + "edsnlp[llm]", ] dev = [ "edsnlp[dev-no-ml]", @@ -260,6 +265,9 @@ where = ["."] "eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" "eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" +# LLM (optional dependencies) +"eds.llm_span_qualifier" = "edsnlp.pipes.qualifiers.llm.factory:create_component" + [project.entry-points."edsnlp_schedules"] "linear" = "edsnlp.training.optimizer:LinearSchedule" @@ -412,6 +420,7 @@ omit-covered-files = false # Some tests may download large objects such as the UMLS # This timeout is mostly here to kill the CI in the case of deadlocks or infinite loops timeout = 600 +asyncio_mode = "auto" [tool.coverage.report] precision = 2 diff --git a/tests/data/test_polars.py b/tests/data/test_polars.py index d6b706288d..869d02ce09 100644 --- a/tests/data/test_polars.py +++ b/tests/data/test_polars.py @@ -49,6 +49,7 @@ def test_read_shuffle_loop(num_cpu_workers: int): .set_processing(num_cpu_workers=num_cpu_workers) for _ in range(2) ) + # This test differs from other data rand perm test as polars rng has changed # between versions (from 1.32 ?) so it's easier to check this notes_a = list(islice(notes_a, 6)) diff --git a/tests/pipelines/qualifiers/test_llm_qualifier.py b/tests/pipelines/qualifiers/test_llm_qualifier.py new file mode 100644 index 0000000000..578c851ac9 --- /dev/null +++ b/tests/pipelines/qualifiers/test_llm_qualifier.py @@ -0,0 +1,182 @@ +from pytest import mark + +import edsnlp +from edsnlp.pipes.qualifiers.llm.llm_qualifier import LLMSpanClassifier +from edsnlp.utils.examples import parse_example +from edsnlp.utils.span_getters import make_span_context_getter + + +@mark.parametrize("label", ["True", None]) +@mark.parametrize("response_mapping", [{"^True$": "1", "^False$": "0"}, None]) +def test_llm_span_classifier(label, response_mapping): + # Patch AsyncLLM to avoid real API calls + class DummyAsyncLLM: + def __init__(self, *args, **kwargs): + # Initialize the dummy LLM + pass + + async def __call__(self, batch_messages): + # Return a dummy label for each message + return [label for _ in batch_messages] + + import edsnlp.pipes.qualifiers.llm.llm_qualifier as llm_mod + + llm_mod.AsyncLLM = DummyAsyncLLM + + nlp = edsnlp.blank("eds") + example = "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." # noqa: E501 + text, entities = parse_example(example) + doc = nlp(text) + doc.ents = [ + doc.char_span(ent.start_char, ent.end_char, label="date") for ent in entities + ] + + # LLMSpanClassifier + nlp.add_pipe( + LLMSpanClassifier( + nlp=nlp, + name="llm", + model="dummy", + span_getter={"ents": True}, + attributes={"_.test_attr": True}, + context_getter=make_span_context_getter( + context_sents=0, + context_words=(5, 5), + ), + prompt={ + "system_prompt": "You are a medical assistant.", + "user_prompt": "You should help us identify dates in the text.", + "prefix_prompt": "Is '{span}' a date? The text is as follows:\n<<< ", + "suffix_prompt": " >>>", + "examples": [ + ( + "\nIs'07/12/2020' a date. The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales. >>>", # noqa: E501 + "False", + ) + ], + }, + api_url="https://dummy", + api_params={ + "max_tokens": 10, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + response_mapping=response_mapping, + n_concurrent_tasks=1, + ) + ) + doc = nlp(doc) + + # Check that the extension is set and the dummy label is applied + for span in doc.ents: + assert hasattr(span._, "test_attr") + if response_mapping is not None: + if label == "True": + assert span._.test_attr == "1" + elif label is None: + assert span._.test_attr is None + + if response_mapping is None: + if label == "True": + assert span._.test_attr == label + elif label is None: + assert span._.test_attr is None + + assert nlp.get_pipe("llm").attributes == {"_.test_attr": True} + + +@mark.parametrize( + "prefix_prompt,suffix_prompt", + [("Is '{span}' a date? The text is as follows:\n<<< ", " >>>"), (None, None)], +) +def test_llm_span_classifier_preprocess(prefix_prompt, suffix_prompt): + # Patch AsyncLLM to avoid real API calls + class DummyAsyncLLM: + def __init__(self, *args, **kwargs): + # Initialize the dummy LLM + pass + + async def __call__(self, batch_messages): + # Return a dummy label for each message + return ["True" for _ in batch_messages] + + import edsnlp.pipes.qualifiers.llm.llm_qualifier as llm_mod + + llm_mod.AsyncLLM = DummyAsyncLLM + + nlp = edsnlp.blank("eds") + example = "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." # noqa: E501 + text, entities = parse_example(example) + doc = nlp(text) + doc.ents = [ + doc.char_span(ent.start_char, ent.end_char, label="date") for ent in entities + ] + + system_prompt = "You are a medical assistant." + user_prompt = "You should help us identify dates in the text." + examples = [ + ( + "\nIs'07/12/2020' a date. The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales. >>>", # noqa: E501 + "False", + ) + ] + + # LLMSpanClassifier + llm = LLMSpanClassifier( + nlp=nlp, + name="llm", + model="dummy", + span_getter={"ents": True}, + attributes={"_.test_attr": True}, + context_getter=make_span_context_getter( + context_sents=0, + context_words=(5, 5), + ), + prompt={ + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "prefix_prompt": prefix_prompt, + "suffix_prompt": suffix_prompt, + "examples": examples, + }, + api_url="https://dummy", + api_params={ + "max_tokens": 10, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + response_mapping=None, + n_concurrent_tasks=1, + ) + + inputs = llm.preprocess(doc) + if (prefix_prompt is not None) and (suffix_prompt is not None): + assert inputs["doc_batch_messages"][0] == [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + { + "role": "user", + "content": examples[0][0], + }, + {"role": "assistant", "content": examples[0][1]}, + { + "role": "user", + "content": "Is '20/02/2025' a date? The text is as follows:\n<<< En RCP du 20/02/2025, patient classé cT3 >>>", # noqa: E501 + }, + ] + else: + assert inputs["doc_batch_messages"][0] == [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + { + "role": "user", + "content": examples[0][0], + }, + {"role": "assistant", "content": examples[0][1]}, + { + "role": "user", + "content": "En RCP du 20/02/2025, patient classé cT3", # noqa: E501 + }, + ] diff --git a/tests/pipelines/qualifiers/test_llm_utils.py b/tests/pipelines/qualifiers/test_llm_utils.py new file mode 100644 index 0000000000..b9d936e587 --- /dev/null +++ b/tests/pipelines/qualifiers/test_llm_utils.py @@ -0,0 +1,213 @@ +from typing import Optional + +import httpx +import respx +from openai.types.chat.chat_completion import ChatCompletion +from pytest import mark + +from edsnlp.pipes.qualifiers.llm.llm_utils import ( + AsyncLLM, + create_prompt_messages, + parse_json_response, +) +from edsnlp.utils.asynchronous import run_async + + +@mark.parametrize("n_concurrent_tasks", [1, 2]) +def test_async_llm(n_concurrent_tasks): + api_url = "http://localhost:8000/v1/" + suffix_url = "chat/completions" + llm_api = AsyncLLM(n_concurrent_tasks=n_concurrent_tasks, api_url=api_url) + + with respx.mock: + respx.post(api_url + suffix_url).mock( + side_effect=[ + httpx.Response( + 200, json={"choices": [{"message": {"content": "positive"}}]} + ), + httpx.Response( + 200, json={"choices": [{"message": {"content": "negative"}}]} + ), + ] + ) + + response = run_async( + llm_api( + batch_messages=[ + [ + {"role": "user", "content": "your prompt here"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "your second prompt here"}, + ], + [{"role": "user", "content": "your second prompt here"}], + ] + ) + ) + assert response == ["positive", "negative"] + + +def test_create_prompt_messages(): + messages = create_prompt_messages( + system_prompt="Hello", + user_prompt="Hi", + examples=[("One", "1")], + final_user_prompt="What is your name?", + ) + messages_expected = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi"}, + {"role": "user", "content": "One"}, + {"role": "assistant", "content": "1"}, + {"role": "user", "content": "What is your name?"}, + ] + assert messages == messages_expected + messages2 = create_prompt_messages( + system_prompt="Hello", + user_prompt=None, + examples=[("One", "1")], + final_user_prompt="What is your name?", + ) + messages_expected2 = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "One"}, + {"role": "assistant", "content": "1"}, + {"role": "user", "content": "What is your name?"}, + ] + assert messages2 == messages_expected2 + + +def create_fake_chat_completion( + choices: int = 1, content: Optional[str] = '{"biopsy":false}' +): + fake_response_data = { + "id": "chatcmpl-fake123", + "object": "chat.completion", + "created": 1699999999, + "model": "toto", + "choices": [ + { + "index": i, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + for i in range(choices) + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25}, + } + + # Create the ChatCompletion object + fake_completion = ChatCompletion.model_validate(fake_response_data) + return fake_completion + + +def test_parse_json_response(): + response = create_fake_chat_completion() + response_format = { + "type": "json_schema", + "json_schema": { + "name": "DateModel", + "schema": { + "properties": {"biopsy": {"title": "Biopsy", "type": "boolean"}}, + "required": ["biopsy"], + "title": "DateModel", + "type": "object", + }, + }, + } + + llm = AsyncLLM(n_concurrent_tasks=1) + parsed_response = llm.parse_messages(response, response_format) + assert parsed_response == {"biopsy": False} + + response = create_fake_chat_completion(content=None) + parsed_response = llm.parse_messages(response, response_format=response_format) + assert parsed_response == {} + + parsed_response = llm.parse_messages(response, response_format=None) + assert parsed_response == response + + parsed_response = llm.parse_messages(None, response_format=None) + assert parsed_response is None + + +@mark.parametrize("n_completions", [1, 2]) +def test_exception_handling(n_completions): + api_url = "http://localhost:8000/v1/" + suffix_url = "chat/completions" + llm_api = AsyncLLM( + n_concurrent_tasks=1, api_url=api_url, n_completions=n_completions + ) + + with respx.mock: + respx.post(api_url + suffix_url).mock( + side_effect=[ + httpx.Response(404, json={"choices": [{}]}), + ] + ) + + response = run_async( + llm_api( + batch_messages=[ + [{"role": "user", "content": "your prompt here"}], + ] + ) + ) + if n_completions == 1: + assert response == [""] + else: + assert response == [[""] * n_completions] + + +@mark.parametrize("errors", ["ignore", "raw"]) +def test_json_decode_error(errors): + raw_response = '{"biopsy";false}' + response_format = { + "type": "json_schema", + "json_schema": { + "name": "DateModel", + "schema": { + "properties": {"biopsy": {"title": "Biopsy", "type": "boolean"}}, + "required": ["biopsy"], + "title": "DateModel", + "type": "object", + }, + }, + } + + response = parse_json_response(raw_response, response_format, errors=errors) + if errors == "ignore": + assert response == {} + else: + assert response == raw_response + + +def test_decode_no_format(): + raw_response = '{"biopsy":false}' + + response = parse_json_response(raw_response, response_format=None) + + assert response == raw_response + + +def test_multiple_completions(n_completions=2): + api_url = "http://localhost:8000/v1/" + suffix_url = "chat/completions" + llm_api = AsyncLLM( + n_concurrent_tasks=1, api_url=api_url, n_completions=n_completions + ) + completion = create_fake_chat_completion(n_completions, content="false") + with respx.mock: + respx.post(api_url + suffix_url).mock( + side_effect=[ + httpx.Response(200, json=completion.model_dump()), + ] + ) + + response = run_async( + llm_api( + batch_messages=[ + [{"role": "user", "content": "your prompt here"}], + ] + ) + ) + assert response == [["false", "false"]] diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 9500b46407..d7f619c1b1 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1,3 +1,5 @@ +import sys + import pytest import edsnlp @@ -13,6 +15,14 @@ def test_pipelines(doc): assert not doc[0]._.history +def is_openai_3_7(e): + return ( + "openai" in str(e) + and sys.version_info.major == 3 + and sys.version_info.minor == 7 + ) + + def test_import_all(): import edsnlp.pipes @@ -23,6 +33,9 @@ def test_import_all(): except (ImportError, AttributeError) as e: if "torch" in str(e): pass + if is_openai_3_7(e): + # Skip tests for OpenAI using python 3.7 + pass def test_non_existing_pipe(): diff --git a/tests/test_entrypoints.py b/tests/test_entrypoints.py index 04df6812e4..8d4eb77b1d 100644 --- a/tests/test_entrypoints.py +++ b/tests/test_entrypoints.py @@ -6,6 +6,7 @@ except ImportError: from importlib_metadata import entry_points +# Torch installation try: import torch.nn except ImportError: @@ -14,6 +15,15 @@ if torch is None: pytest.skip("torch not installed", allow_module_level=True) +# openai installation +try: + import openai +except ImportError: + openai = None + +if openai is None: + pytest.skip("openai not installed", allow_module_level=True) + def test_entrypoints(): ep = entry_points()