diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9b24fcf..362a2e7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - id: check-yaml - id: detect-private-key - repo: https://github.com/tox-dev/pyproject-fmt - rev: "v2.6.0" + rev: "v2.11.1" hooks: - id: pyproject-fmt - repo: https://github.com/citation-file-format/cffconvert @@ -39,12 +39,12 @@ repos: - id: yamllint exclude: pre-commit-config.yaml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.13.0" + rev: "v0.14.8" hooks: - id: ruff-format - id: ruff-check - repo: https://github.com/rhysd/actionlint - rev: v1.7.7 + rev: v1.7.9 hooks: - id: actionlint - repo: https://gitlab.com/vojko.pribudic.foss/pre-commit-update diff --git a/backend/Dockerfile b/backend/Dockerfile index e8d2d43..7a9ba46 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -25,9 +25,9 @@ COPY --exclude=.venv ./packages/manugen-ai/ /packages/manugen-ai/ ENV VIRTUAL_ENV=/app/.venv ENV UV_LINK_MODE=copy ENV UV_NO_SYNC=1 -COPY ./backend/pyproject.toml ./ +COPY ./backend/pyproject.toml ./backend/uv.lock ./ RUN --mount=type=cache,target=/root/.cache/uv \ - uv sync --cache-dir /root/.cache/uv + uv sync --frozen --cache-dir /root/.cache/uv # copy backend contents into working dir COPY --exclude=.venv ./backend/ . 
diff --git a/packages/manugen-ai/pytest.ini b/packages/manugen-ai/pytest.ini new file mode 100644 index 0000000..f498ea7 --- /dev/null +++ b/packages/manugen-ai/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +addopts = -m "not slow" +markers = + slow: marks tests as slow (deselected by default) diff --git a/packages/manugen-ai/src/manugen_ai/data.py b/packages/manugen-ai/src/manugen_ai/data.py index 76a6a6c..a29d682 100644 --- a/packages/manugen-ai/src/manugen_ai/data.py +++ b/packages/manugen-ai/src/manugen_ai/data.py @@ -8,6 +8,7 @@ import duckdb import numpy as np import pyarrow as pa +from duckdb.typing import VARCHAR from manugen_ai.utils import download_file_if_not_available @@ -275,9 +276,7 @@ def create_withdrarxiv_embeddings( """ ) - conn.create_function( - "embed", embed, [duckdb.typing.VARCHAR], f"FLOAT[{get_embedding_size()}]" - ) + conn.create_function("embed", embed, [VARCHAR], f"FLOAT[{get_embedding_size()}]") # Batch-compute embeddings for every abstract batch_size = 100 @@ -373,9 +372,7 @@ def search_withdrarxiv_embeddings(query: str, top_k: int = 2): ) conn = duckdb.connect(target_db) - conn.create_function( - "embed", embed, [duckdb.typing.VARCHAR], f"FLOAT[{get_embedding_size()}]" - ) + conn.create_function("embed", embed, [VARCHAR], f"FLOAT[{get_embedding_size()}]") q = {"q": query, "k": top_k} df = ( diff --git a/packages/manugen-ai/tests/test_data.py b/packages/manugen-ai/tests/test_data.py new file mode 100644 index 0000000..2040124 --- /dev/null +++ b/packages/manugen-ai/tests/test_data.py @@ -0,0 +1,80 @@ +""" +Tests for the 'data' module, mostly consisting of embedding models. 
+""" + +import json +from pathlib import Path + +import pytest +from manugen_ai.data import create_withdrarxiv_embeddings, search_withdrarxiv_embeddings + + +@pytest.mark.slow +def test_create_withdrarxiv_embedding(): + """ + Test creating the withdrarxiv embeddings database using a table of retracted + papers from arXiv and an embedding model, either via Gemini's API or via + a local Flag embedding model. + + The test is marked 'slow' and thus excluded from the default test suite + for a few reasons: + - it takes ~2 minutes to run using the Gemini embeddings API, which incurs + API usage costs + - it takes a lot longer with the flag embeddings, which require + downloading a large model file, BAAI/bge-m3, from Hugging Face. + - the withdrarxiv dataset on which it relies can't be included in the repo + and must be manually downloaded; see create_withdrarxiv_embeddings()'s + docstring for details. + + Depending on the value of the env var USE_GEMINI_EMBEDDINGS, this test will + use either Google's Gemini embeddings via Google's API (if set to "1") or + the "flag" embedding model from Hugging Face. The Gemini embeddings require + that the GOOGLE_API_KEY env var be set to a valid API key. + + The Flag embedding model uses the FLAGEMBEDDING_MODEL_OR_PATH env var + to determine where the model is already located, or if given a model + what model to download. By default it'll download "BAAI/bge-m3" from + Hugging Face. 
+ """ + datafiles_dir = ( + Path(__file__).resolve().parent.parent / "src" / "manugen_ai" / "data" + ) + target_db = "withdrarxiv_test_embeddings.duckdb" + + try: + # construct the db; we're mostly seeing if this runs without throwing anything + resulting_db_path = create_withdrarxiv_embeddings(target_db) + + # check that the resulting db path is what we passed in + assert resulting_db_path == target_db + + # ensure the db file was actually created + full_db_path = datafiles_dir / target_db + assert full_db_path.exists() + + finally: + full_db_path = datafiles_dir / target_db + full_db_path.unlink(missing_ok=True) + + +def test_search_withdrarxiv_embeddings(): + """ + Test searching the withdrarxiv embeddings. + + As of 2025-12-03, the retraction db produced these results for the query + "What is the role of quantum entanglement in quantum computing?": + [ + {"related_retraction_reasons":"Just because interleaving bisimilarity based ACP cannot be reversed, some conclusions of this paper are wrong and cannot be remedied, I beg to withdraw this paper"}, + {"related_retraction_reasons":"a wrong formula"}, + {"related_retraction_reasons":"The paper is withdrawn because of many flaws in the manuscript"} + ] + """ + results = json.loads( + search_withdrarxiv_embeddings( + "What is the role of quantum entanglement in quantum computing?", top_k=3 + ) + ) + + assert len(results) == 3 + for result in results: + assert "related_retraction_reasons" in result