6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -7,11 +7,11 @@ repos:
- id: ruff-format
name: "Ruff formatter"
args: [--config=pyproject.toml]
files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'
- id: ruff
name: "Ruff linter"
args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
files: '^(mellea|tests).*\.(py|ipynb)$'
files: '^(mellea).*\.(py|ipynb)$'

- repo: local
hooks:
@@ -20,7 +20,7 @@ repos:
entry: uv run --no-sync mypy mellea
pass_filenames: false
language: system
files: '\.py$'
files: '^(mellea|test|cli|docs).*\.(py|ipynb)$'

- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.7.8
20 changes: 15 additions & 5 deletions test/backends/test_huggingface.py
@@ -10,11 +10,20 @@
from mellea.backends.formatter import TemplateFormatter
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.types import ModelOption
from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk,
SimpleContext)
from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
Requirement, ValidationResult,
default_output_to_bool)
from mellea.stdlib.base import (
CBlock,
ChatContext,
Context,
ModelOutputThunk,
SimpleContext,
)
from mellea.stdlib.requirement import (
ALoraRequirement,
LLMaJRequirement,
Requirement,
ValidationResult,
default_output_to_bool,
)


@pytest.fixture(scope="module")
@@ -40,6 +49,7 @@ def session(backend):
yield session
session.reset()


@pytest.mark.qualitative
def test_adapters(backend):
assert len(backend._added_adapters.items()) > 0
16 changes: 8 additions & 8 deletions test/backends/test_litellm_ollama.py
@@ -26,9 +26,7 @@ def backend(gh_run: int):
url = url.replace("127.0.0.1", "http://localhost")

return LiteLLMBackend(
model_id=_MODEL_ID,
base_url=url,
model_options={"api_base": url},
model_id=_MODEL_ID, base_url=url, model_options={"api_base": url}
)
else:
return LiteLLMBackend(model_id=_MODEL_ID)
@@ -111,12 +109,11 @@ def test_litellm_ollama_instruct_options(session):
ModelOption.SEED: 123,
ModelOption.TEMPERATURE: 0.5,
ModelOption.MAX_NEW_TOKENS: 100,

# Ollama thinking controls currently broken on Granite; see
# Ollama thinking controls currently broken on Granite; see
# https://github.com/ollama/ollama/issues/10983
# TODO: Re-enable when this upstream bug gets fixed.
#ModelOption.THINKING: True,
#"reasoning_effort": True,
# ModelOption.THINKING: True,
# "reasoning_effort": True,
"homer_simpson": "option should be kicked out",
}

@@ -144,6 +141,7 @@ def is_happy(text: str) -> bool:
# should yield True - but, of course, this is model dependent
assert h is True


async def test_generate_from_raw(session):
prompts = [
"what is 1+1?",
@@ -157,7 +155,9 @@ async def test_generate_from_raw(session):
actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
)

assert len(results) == 1, "ollama doesn't support batching; litellm should send a single message containing all prompts"
assert len(results) == 1, (
"ollama doesn't support batching; litellm should send a single message containing all prompts"
)
assert results[0].value is not None


11 changes: 4 additions & 7 deletions test/backends/test_litellm_watsonx.py
@@ -41,18 +41,15 @@ def test_multiple_sync_funcs(session):

@pytest.mark.qualitative
async def test_generate_from_raw(session):
prompts = [
"what is 1+1?",
"what is 2+2?",
"what is 3+3?",
"what is 4+2+2?",
]
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+2+2?"]

results = await session.backend.generate_from_raw(
actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx
)

assert len(results) == 1, "litellm converts a batch request for watsonx into a single message"
assert len(results) == 1, (
"litellm converts a batch request for watsonx into a single message"
)
assert results[0].value is not None


1 change: 1 addition & 0 deletions test/backends/test_openai_ollama.py
@@ -122,6 +122,7 @@ async def test_generate_from_raw(m_session):
actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx
)


# Default OpenAI implementation doesn't support structured outputs for the completions API.
# def test_generate_from_raw_with_format(self):
# prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
15 changes: 11 additions & 4 deletions test/backends/test_openai_vllm/test_openai_vllm.py
@@ -11,8 +11,12 @@
from mellea.backends.openai import OpenAIBackend
from mellea.backends.types import ModelOption, _ServerType
from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk
from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement,
Requirement, req)
from mellea.stdlib.requirement import (
ALoraRequirement,
LLMaJRequirement,
Requirement,
req,
)

# The vllm tests are disabled by default, because we need a test environment with the vLLM server running.
# We use an env var VLLM_TESTS_ENABLED to enable these tests.
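# Not part of this diff: a minimal sketch of the VLLM_TESTS_ENABLED gate described
# in the comment above, assuming a module-level pytest skip marker; the repository
# may implement the gate differently.
import os
import pytest

vllm_tests_enabled = os.environ.get("VLLM_TESTS_ENABLED", "0") == "1"

pytestmark = pytest.mark.skipif(
    not vllm_tests_enabled,
    reason="vLLM tests need a running vLLM server; set VLLM_TESTS_ENABLED=1 to run them.",
)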
@@ -138,8 +142,11 @@ class TestOpenAIALoraStuff:
base_url="http://localhost:8000/v1",
api_key="EMPTY",
)
backend.add_adapter(GraniteCommonAdapter("requirement_check",
base_model_name=backend.base_model_name))
backend.add_adapter(
GraniteCommonAdapter(
"requirement_check", base_model_name=backend.base_model_name
)
)

m = MelleaSession(backend, ctx=ChatContext())

1 change: 1 addition & 0 deletions test/stdlib_basics/test_base.py
@@ -26,5 +26,6 @@ def format_for_llm(self) -> str:
c = _ClosuredComponent()
assert len(c.parts()) == 0


if __name__ == "__main__":
pytest.main([__file__])
1 change: 1 addition & 0 deletions test/stdlib_basics/test_chat.py
@@ -3,6 +3,7 @@
from mellea.stdlib.base import Document
from mellea.stdlib.chat import Message


def test_message_with_docs():
doc = Document("I'm text!", "Im a title!")
msg = Message("user", "hello", documents=[doc])
93 changes: 72 additions & 21 deletions test/stdlib_basics/test_genslot.py
@@ -5,22 +5,27 @@
from mellea.backends.model_ids import META_LLAMA_3_2_1B
from mellea.backends.ollama import OllamaModelBackend
from mellea.stdlib.base import ChatContext, Context
from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, PreconditionException, SyncGenerativeSlot
from mellea.stdlib.genslot import (
AsyncGenerativeSlot,
GenerativeSlot,
PreconditionException,
SyncGenerativeSlot,
)
from mellea.stdlib.requirement import Requirement, simple_validate
from mellea.stdlib.sampling.base import RejectionSamplingStrategy
from mellea.stdlib.session import MelleaSession


@pytest.fixture(scope="module")
def backend(gh_run: int):
"""Shared backend."""
if gh_run == 1:
return OllamaModelBackend(
model_id=META_LLAMA_3_2_1B.ollama_name, # type: ignore
model_id=META_LLAMA_3_2_1B.ollama_name # type: ignore
)
else:
return OllamaModelBackend(
model_id="granite3.3:8b",
)
return OllamaModelBackend(model_id="granite3.3:8b")


@generative
def classify_sentiment(text: str) -> Literal["positive", "negative"]: ...
@@ -81,26 +86,66 @@ async def test_async_gen_slot(session):
r1 = async_write_short_sentence(session, topic="cats")
r2 = async_write_short_sentence(session, topic="dogs")

r3, c3 = await async_write_short_sentence(context=session.ctx, backend=session.backend, topic="fish")
r3, c3 = await async_write_short_sentence(
context=session.ctx, backend=session.backend, topic="fish"
)
results = await asyncio.gather(r1, r2)

assert isinstance(r3, str)
assert isinstance(c3, Context)
assert len(results) == 2


@pytest.mark.parametrize(
"arg_choices,kwarg_choices,errs",
[
pytest.param(["m"], ["func1", "func2", "func3"], False, id="session"),
pytest.param(["context"], ["backend"], False, id="context and backend"),
pytest.param(["backend"], ["func1", "func2", "func3"], True, id="backend without context"),
pytest.param(
["backend"], ["func1", "func2", "func3"], True, id="backend without context"
),
pytest.param(["m"], ["m"], True, id="duplicate arg and kwarg"),
pytest.param(["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], [], True, id="original func args as positional args"),
pytest.param([], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"),
pytest.param([], ["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], False, id="all kwargs"),
pytest.param([], ["func1", "m", "func2", "requirements", "func3"], False, id="interspersed kwargs"),
pytest.param([], [], True, id="missing required args")
]
pytest.param(
[
"m",
"precondition_requirements",
"requirements",
"strategy",
"model_options",
"func1",
"func2",
"func3",
],
[],
True,
id="original func args as positional args",
),
pytest.param(
[], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"
),
pytest.param(
[],
[
"m",
"precondition_requirements",
"requirements",
"strategy",
"model_options",
"func1",
"func2",
"func3",
],
False,
id="all kwargs",
),
pytest.param(
[],
["func1", "m", "func2", "requirements", "func3"],
False,
id="interspersed kwargs",
),
pytest.param([], [], True, id="missing required args"),
],
)
def test_arg_extraction(backend, arg_choices, kwarg_choices, errs):
"""Tests the internal extract_args_and_kwargs function.
@@ -156,35 +201,40 @@ def test_arg_extraction(backend, arg_choices, kwarg_choices, errs):
except Exception as e:
found_err = True
err = e

if errs:
assert found_err, "expected an exception and got none"
else:
assert not found_err, f"got unexpected err: {err}"


def test_disallowed_parameter_names():
with pytest.raises(ValueError):

@generative
def test(backend):
...
def test(backend): ...


def test_precondition_failure(session):
with pytest.raises(PreconditionException):
classify_sentiment(
m=session,
text="hello",
precondition_requirements=[
Requirement("forced failure", validation_fn=simple_validate(lambda x: (False, "")))
]
Requirement(
"forced failure",
validation_fn=simple_validate(lambda x: (False, "")),
)
],
)


def test_requirement(session):
classify_sentiment(
m=session,
text="hello",
requirements=["req1", "req2", Requirement("req3")]
m=session, text="hello", requirements=["req1", "req2", Requirement("req3")]
)


def test_with_no_args(session):
@generative
def generate_text() -> str:
Expand All @@ -193,5 +243,6 @@ def generate_text() -> str:

generate_text(m=session)


if __name__ == "__main__":
pytest.main([__file__])
4 changes: 3 additions & 1 deletion test/stdlib_basics/test_reqlib_tools.py
@@ -1,10 +1,12 @@
import pytest
from mellea.stdlib.reqlib.tools import _name2str


def test_name2str():
"""Test handling when no Python code is present."""

def test123():
pass

assert _name2str(test123) == "test123"
assert _name2str("test1234") == "test1234"

3 changes: 1 addition & 2 deletions test/stdlib_basics/test_session.py
@@ -135,12 +135,11 @@ def test_session_copy_with_context_ops(m_session):


class TestPowerup:
def hello(m:MelleaSession):
def hello(m: MelleaSession):
return "hello"


def test_powerup(m_session):

MelleaSession.powerup(TestPowerup)

assert "hello" == m_session.hello()