diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 56c73f65..43a5e0e7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,11 +7,11 @@ repos:
       - id: ruff-format
         name: "Ruff formatter"
         args: [--config=pyproject.toml]
-        files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
       - id: ruff
         name: "Ruff linter"
         args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
-        files: '^(mellea|tests).*\.(py|ipynb)$'
+        files: '^(mellea).*\.(py|ipynb)$'
 
   - repo: local
     hooks:
@@ -20,7 +20,7 @@ repos:
         entry: uv run --no-sync mypy mellea
         pass_filenames: false
         language: system
-        files: '\.py$'
+        files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
 
   - repo: https://github.com/astral-sh/uv-pre-commit
     rev: 0.7.8
diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py
index c2a5497f..403f4912 100644
--- a/test/backends/test_huggingface.py
+++ b/test/backends/test_huggingface.py
@@ -10,11 +10,20 @@ from mellea.backends.formatter import TemplateFormatter
 from mellea.backends.huggingface import LocalHFBackend
 from mellea.backends.types import ModelOption
-from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk,
-                                SimpleContext)
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, ValidationResult,
-                                       default_output_to_bool)
+from mellea.stdlib.base import (
+    CBlock,
+    ChatContext,
+    Context,
+    ModelOutputThunk,
+    SimpleContext,
+)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)
 
 
 @pytest.fixture(scope="module")
@@ -40,6 +49,7 @@ def session(backend):
     yield session
     session.reset()
 
+
 @pytest.mark.qualitative
 def test_adapters(backend):
     assert len(backend._added_adapters.items()) > 0
diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py
index f6f10807..e49baa95 100644
--- a/test/backends/test_litellm_ollama.py
+++ b/test/backends/test_litellm_ollama.py
@@ -26,9 +26,7 @@ def backend(gh_run: int):
         url = url.replace("127.0.0.1", "http://localhost")
 
         return LiteLLMBackend(
-            model_id=_MODEL_ID,
-            base_url=url,
-            model_options={"api_base": url},
+            model_id=_MODEL_ID, base_url=url, model_options={"api_base": url}
         )
     else:
         return LiteLLMBackend(model_id=_MODEL_ID)
@@ -111,12 +109,11 @@ def test_litellm_ollama_instruct_options(session):
         ModelOption.SEED: 123,
         ModelOption.TEMPERATURE: 0.5,
         ModelOption.MAX_NEW_TOKENS: 100,
-
-        # Ollama thinking controls currently broken on Granite; see 
+        # Ollama thinking controls currently broken on Granite; see
         # https://github.com/ollama/ollama/issues/10983
         # TODO: Re-enable when this upstream bug gets fixed.
-        #ModelOption.THINKING: True,
-        #"reasoning_effort": True,
+        # ModelOption.THINKING: True,
+        # "reasoning_effort": True,
         "homer_simpson": "option should be kicked out",
     }
 
@@ -144,6 +141,7 @@ def is_happy(text: str) -> bool:
     # should yield to true - but, of course, is model dependent
     assert h is True
 
+
 async def test_generate_from_raw(session):
     prompts = [
         "what is 1+1?",
@@ -157,7 +155,9 @@ async def test_generate_from_raw(session):
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )
 
-    assert len(results) == 1, "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    assert len(results) == 1, (
+        "ollama doesn't support batching; litellm should send a single message containing all prompts"
+    )
 
     assert results[0].value is not None
 
diff --git a/test/backends/test_litellm_watsonx.py b/test/backends/test_litellm_watsonx.py
index 76030f8e..352cec57 100644
--- a/test/backends/test_litellm_watsonx.py
+++ b/test/backends/test_litellm_watsonx.py
@@ -41,18 +41,15 @@ def test_multiple_sync_funcs(session):
 
 @pytest.mark.qualitative
 async def test_generate_from_raw(session):
-    prompts = [
-        "what is 1+1?",
-        "what is 2+2?",
-        "what is 3+3?",
-        "what is 4+2+2?",
-    ]
+    prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+2+2?"]
 
     results = await session.backend.generate_from_raw(
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
     )
 
-    assert len(results) == 1, "litellm converts a batch request for watsonx into a single message"
+    assert len(results) == 1, (
+        "litellm converts a batch request for watsonx into a single message"
+    )
 
     assert results[0].value is not None
 
diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py
index 24f656bc..1848287c 100644
--- a/test/backends/test_openai_ollama.py
+++ b/test/backends/test_openai_ollama.py
@@ -122,6 +122,7 @@ async def test_generate_from_raw(m_session):
         actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx
     )
 
+
 # Default OpenAI implementation doesn't support structured outputs for the completions API.
 # def test_generate_from_raw_with_format(self):
 #     prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py
index 537392a4..30f17a86 100644
--- a/test/backends/test_openai_vllm/test_openai_vllm.py
+++ b/test/backends/test_openai_vllm/test_openai_vllm.py
@@ -11,8 +11,12 @@ from mellea.backends.openai import OpenAIBackend
 from mellea.backends.types import ModelOption, _ServerType
 from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk
-from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
-                                       Requirement, req)
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    req,
+)
 
 # The vllm tests are disabled by default, because we need a test environment with the vLLM server running.
 # We use an env var VLLM_TESTS_ENABLED to enable these tests.
@@ -138,8 +142,11 @@ class TestOpenAIALoraStuff:
         base_url="http://localhost:8000/v1",
         api_key="EMPTY",
     )
-    backend.add_adapter(GraniteCommonAdapter("requirement_check",
-                                             base_model_name=backend.base_model_name))
+    backend.add_adapter(
+        GraniteCommonAdapter(
+            "requirement_check", base_model_name=backend.base_model_name
+        )
+    )
 
     m = MelleaSession(backend, ctx=ChatContext())
 
diff --git a/test/stdlib_basics/test_base.py b/test/stdlib_basics/test_base.py
index 917619e0..e19c6adc 100644
--- a/test/stdlib_basics/test_base.py
+++ b/test/stdlib_basics/test_base.py
@@ -26,5 +26,6 @@ def format_for_llm(self) -> str:
     c = _ClosuredComponent()
     assert len(c.parts()) == 0
 
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/test/stdlib_basics/test_chat.py b/test/stdlib_basics/test_chat.py
index 3ae911b9..819b2796 100644
--- a/test/stdlib_basics/test_chat.py
+++ b/test/stdlib_basics/test_chat.py
@@ -3,6 +3,7 @@
 from mellea.stdlib.base import Document
 from mellea.stdlib.chat import Message
 
+
 def test_message_with_docs():
     doc = Document("I'm text!", "Im a title!")
     msg = Message("user", "hello", documents=[doc])
diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py
index 78ff44d9..e7e0bfb3 100644
--- a/test/stdlib_basics/test_genslot.py
+++ b/test/stdlib_basics/test_genslot.py
@@ -5,22 +5,27 @@
 from mellea.backends.model_ids import META_LLAMA_3_2_1B
 from mellea.backends.ollama import OllamaModelBackend
 from mellea.stdlib.base import ChatContext, Context
-from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, PreconditionException, SyncGenerativeSlot
+from mellea.stdlib.genslot import (
+    AsyncGenerativeSlot,
+    GenerativeSlot,
+    PreconditionException,
+    SyncGenerativeSlot,
+)
 from mellea.stdlib.requirement import Requirement, simple_validate
 from mellea.stdlib.sampling.base import RejectionSamplingStrategy
 from mellea.stdlib.session import MelleaSession
 
+
 @pytest.fixture(scope="module")
 def backend(gh_run: int):
     """Shared backend."""
     if gh_run == 1:
         return OllamaModelBackend(
-            model_id=META_LLAMA_3_2_1B.ollama_name,  # type: ignore
+            model_id=META_LLAMA_3_2_1B.ollama_name  # type: ignore
         )
     else:
-        return OllamaModelBackend(
-            model_id="granite3.3:8b",
-        )
+        return OllamaModelBackend(model_id="granite3.3:8b")
+
 
 @generative
 def classify_sentiment(text: str) -> Literal["positive", "negative"]: ...
@@ -81,26 +86,66 @@ async def test_async_gen_slot(session):
     r1 = async_write_short_sentence(session, topic="cats")
     r2 = async_write_short_sentence(session, topic="dogs")
-    r3, c3 = await async_write_short_sentence(context=session.ctx, backend=session.backend, topic="fish")
+    r3, c3 = await async_write_short_sentence(
+        context=session.ctx, backend=session.backend, topic="fish"
+    )
 
     results = await asyncio.gather(r1, r2)
     assert isinstance(r3, str)
     assert isinstance(c3, Context)
     assert len(results) == 2
 
+
 @pytest.mark.parametrize(
     "arg_choices,kwarg_choices,errs",
     [
         pytest.param(["m"], ["func1", "func2", "func3"], False, id="session"),
         pytest.param(["context"], ["backend"], False, id="context and backend"),
-        pytest.param(["backend"], ["func1", "func2", "func3"], True, id="backend without context"),
+        pytest.param(
+            ["backend"], ["func1", "func2", "func3"], True, id="backend without context"
+        ),
         pytest.param(["m"], ["m"], True, id="duplicate arg and kwarg"),
-        pytest.param(["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], [], True, id="original func args as positional args"),
-        pytest.param([], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"),
-        pytest.param([], ["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], False, id="all kwargs"),
-        pytest.param([], ["func1", "m", "func2", "requirements", "func3"], False, id="interspersed kwargs"),
-        pytest.param([], [], True, id="missing required args")
-    ]
+        pytest.param(
+            [
+                "m",
+                "precondition_requirements",
+                "requirements",
+                "strategy",
+                "model_options",
+                "func1",
+                "func2",
+                "func3",
+            ],
+            [],
+            True,
+            id="original func args as positional args",
+        ),
+        pytest.param(
+            [], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"
+        ),
+        pytest.param(
+            [],
+            [
+                "m",
+                "precondition_requirements",
+                "requirements",
+                "strategy",
+                "model_options",
+                "func1",
+                "func2",
+                "func3",
+            ],
+            False,
+            id="all kwargs",
+        ),
+        pytest.param(
+            [],
+            ["func1", "m", "func2", "requirements", "func3"],
+            False,
+            id="interspersed kwargs",
+        ),
+        pytest.param([], [], True, id="missing required args"),
+    ],
 )
 def test_arg_extraction(backend, arg_choices, kwarg_choices, errs):
     """Tests the internal extract_args_and_kwargs function.
 
@@ -156,17 +201,19 @@ def test_arg_extraction(backend, arg_choices, kwarg_choices, errs):
     except Exception as e:
         found_err = True
         err = e
-    
+
     if errs:
         assert found_err, "expected an exception and got none"
     else:
         assert not found_err, f"got unexpected err: {err}"
 
+
 def test_disallowed_parameter_names():
     with pytest.raises(ValueError):
+
         @generative
-        def test(backend):
-            ...
+        def test(backend): ...
+
 
 def test_precondition_failure(session):
     with pytest.raises(PreconditionException):
@@ -174,17 +221,20 @@
         m=session,
         text="hello",
         precondition_requirements=[
-            Requirement("forced failure", validation_fn=simple_validate(lambda x: (False, "")))
-        ]
+            Requirement(
+                "forced failure",
+                validation_fn=simple_validate(lambda x: (False, "")),
+            )
+        ],
     )
 
+
 def test_requirement(session):
     classify_sentiment(
-        m=session,
-        text="hello",
-        requirements=["req1", "req2", Requirement("req3")]
+        m=session, text="hello", requirements=["req1", "req2", Requirement("req3")]
     )
 
+
 def test_with_no_args(session):
     @generative
     def generate_text() -> str:
@@ -193,5 +243,6 @@ def generate_text() -> str:
 
     generate_text(m=session)
 
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/test/stdlib_basics/test_reqlib_tools.py b/test/stdlib_basics/test_reqlib_tools.py
index b7f771f2..9e92d890 100644
--- a/test/stdlib_basics/test_reqlib_tools.py
+++ b/test/stdlib_basics/test_reqlib_tools.py
@@ -1,10 +1,12 @@
 import pytest
 from mellea.stdlib.reqlib.tools import _name2str
 
+
 def test_name2str():
     """Test handling when no Python code is present."""
+
     def test123():
         pass
+
     assert _name2str(test123) == "test123"
     assert _name2str("test1234") == "test1234"
-
diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py
index a1722e83..6694246c 100644
--- a/test/stdlib_basics/test_session.py
+++ b/test/stdlib_basics/test_session.py
@@ -135,12 +135,11 @@ def test_session_copy_with_context_ops(m_session):
 
 
 class TestPowerup:
-    def hello(m:MelleaSession):
+    def hello(m: MelleaSession):
         return "hello"
 
 
 def test_powerup(m_session):
-
     MelleaSession.powerup(TestPowerup)
     assert "hello" == m_session.hello()
 
diff --git a/test/stdlib_intrinsics/test_rag/test_rag.py b/test/stdlib_intrinsics/test_rag/test_rag.py
index 18985f02..e2477cf3 100644
--- a/test/stdlib_intrinsics/test_rag/test_rag.py
+++ b/test/stdlib_intrinsics/test_rag/test_rag.py
@@ -23,7 +23,7 @@
 @pytest.fixture(name="backend")
 def _backend():
     """Backend used by the tests in this file."""
-    
+
     # Prevent thrashing if the default device is CPU
     torch.set_num_threads(4)
 
@@ -57,12 +57,7 @@ def _read_input_json(file_name: str):
     documents = []
     if "extra_body" in json_data and "documents" in json_data["extra_body"]:
         for d in json_data["extra_body"]["documents"]:
-            documents.append(
-                Document(
-                    text=d["text"],
-                    doc_id=d["doc_id"],
-                )
-            )
+            documents.append(Document(text=d["text"], doc_id=d["doc_id"]))
 
     return context, next_user_turn, documents
 
@@ -177,9 +172,11 @@
     assert result == expected_rewrite
 
     # Canned input always gets rewritten. Set threshold to disable the rewrite.
-    result = rag.rewrite_answer_for_relevance(answer, docs, context, backend,
-                                              rewrite_threshold=0.0)
+    result = rag.rewrite_answer_for_relevance(
+        answer, docs, context, backend, rewrite_threshold=0.0
+    )
     assert result == answer
+
 
 if __name__ == "__main__":
     pytest.main([__file__])