Merged

19 commits
6e98053
Adding a fix to pass `reasoning_effort` in conditionally
avinash2692 Dec 24, 2025
ab9e56b
adding tests
avinash2692 Dec 24, 2025
0862cf0
Fixes #274
nrfulton Dec 26, 2025
29481f7
Adds GPT 5.1 model identifier.
nrfulton Dec 26, 2025
1426ee9
Changes OpenAI Backend default model_id to GPT 5.1.
nrfulton Dec 26, 2025
c11fbef
Fixes bug: GenSlots did not work with OpenAI platform.
nrfulton Dec 26, 2025
0bde6ec
Adds inline documentation for OpenAI model options monkey patching.
nrfulton Dec 26, 2025
4d87c83
removes debug print stmt.
nrfulton Dec 26, 2025
f87f86b
adding a comment about reasoning_effort in openai sdk
avinash2692 Jan 5, 2026
e7e161b
Merge branch 'fix/270-openai-reasoning-effort' of https://github.com/…
avinash2692 Jan 5, 2026
b6d16a6
Merge branch 'main' into fix/270-openai-reasoning-effort
avinash2692 Jan 6, 2026
a94205d
removing all instances of hf_model_id in openai backend
avinash2692 Jan 6, 2026
1e7c1b4
removing apply_chat_template and adding assertions for env variable
avinash2692 Jan 6, 2026
a695cb4
adding some tests for param checking
avinash2692 Jan 6, 2026
41a0c62
changing env variable handling logic.
avinash2692 Jan 6, 2026
c905843
base_url check is now a warning
avinash2692 Jan 6, 2026
0a7747a
fix: change warning message in openai.py
jakelorocco Jan 6, 2026
d0ecfc7
marking test as qualitative cause it's causing timeouts in github act…
avinash2692 Jan 6, 2026
17c2862
Merge branch 'fix/270-openai-reasoning-effort' of https://github.com/…
avinash2692 Jan 6, 2026
13 changes: 10 additions & 3 deletions mellea/backends/model_ids.py
@@ -17,6 +17,7 @@ class ModelIdentifier:
ollama_name: str | None = None
watsonx_name: str | None = None
mlx_name: str | None = None
openai_name: str | None = None

hf_tokenizer_name: str | None = None # if None, is the same as hf_model_name

@@ -134,9 +135,9 @@ class ModelIdentifier:

QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")

######################
#### OpenAI models ###
######################
###########################
#### OpenAI open models ###
###########################

OPENAI_GPT_OSS_20B = ModelIdentifier(
hf_model_name="openai/gpt-oss-20b", ollama_name="gpt-oss:20b"
)

OPENAI_GPT_OSS_120B = ModelIdentifier(
hf_model_name="openai/gpt-oss-120b", ollama_name="gpt-oss:120b"
)

###########################
#### OpenAI prop models ###
###########################

OPENAI_GPT_5_1 = ModelIdentifier(openai_name="gpt-5.1")

#####################
#### Misc models ####
#####################
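For context, a minimal usage sketch of the new `openai_name` field and GPT 5.1 identifier added above; the module paths are assumed from the file layout shown in this PR, and the API key is a placeholder:

# Usage sketch only (not part of this PR's diff); module paths assumed from the files shown above.
from mellea.backends import model_ids
from mellea.backends.openai import OpenAIBackend

backend = OpenAIBackend(
    model_id=model_ids.OPENAI_GPT_5_1,  # resolved via the new openai_name field ("gpt-5.1")
    api_key="sk-...",  # placeholder; an API key is required for the OpenAI platform
)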
88 changes: 64 additions & 24 deletions mellea/backends/openai.py
@@ -72,7 +72,7 @@ class OpenAIBackend(FormatterBackend, AdapterMixin):

def __init__(
self,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
model_id: str | ModelIdentifier = model_ids.OPENAI_GPT_5_1,
formatter: Formatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
@@ -142,26 +142,30 @@ def __init__(

self.default_to_constraint_checking_alora = default_to_constraint_checking_alora

self._model_id = model_id
match model_id:
case str():
self._hf_model_id = model_id
self._model_id = model_id
case ModelIdentifier():
assert model_id.hf_model_name is not None, (
"model_id is None. This can also happen if the ModelIdentifier has no hf_model_id name set."
assert model_id.openai_name is not None, (
"model_id is None. This can also happen if the ModelIdentifier has no `openai_name` name set."
)
self._hf_model_id = model_id.hf_model_name
self._model_id = model_id.openai_name

if base_url is None:
self._base_url = "http://localhost:11434/v1" # ollama
else:
self._base_url = base_url
if api_key is None:
FancyLogger.get_logger().warning(
"You are using an OpenAI backend with no api_key. Because no API key was provided, mellea assumes you intend to use the openai-compatible interface to your local ollama instance. If you intend to use OpenAI's platform you must specify your API key when instantiating your Mellea session/backend object."
)
self._base_url: str | None = "http://localhost:11434/v1" # ollama
self._api_key = "ollama"
Contributor:
I think we are close to defaults that make sense here. I think if the user specifies a base_url we should always use that base_url (even if no api_key is set). I also wonder if we should default the api_key to "ollama" in those situations.

Otherwise, we have no way to target arbitrary localhost ports that don't require an api_key.

For example (and this isn't the best since it uses LiteLLM and we have a separate backend for that), LiteLLM has a proxy that you can run locally. This proxy stores the api_key information itself, so you can target an arbitrary localhost port without an api_key.

My proposed solution would be to just set the parameter default values to work for the ollama version (i.e. api_key="ollama" and base_url="http://localhost:11434/v1"). Then users can override these values. I think this would also allow users to explicitly set api_key / base_url to None and have the underlying OpenAI SDK still automatically pick up their env vars (without the risk of users accidentally incurring expenses).

Contributor:
Consensus: just pass the args through to the OpenAI SDK. Don't do argument handling such as this.

Contributor:
I thought the final verdict here was to not do any fancy handling and just pass args through as None.

Contributor (Author):
Ah, sorry, missed this. (A sketch of the pass-through approach follows this hunk.)
else:
self._base_url = base_url
self._api_key = api_key

self._server_type = _server_type(self._base_url)
self._server_type: _ServerType = (
_server_type(self._base_url)
if self._base_url is not None
else _ServerType.OPENAI
) # type: ignore

self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
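As a reference for the review thread above, here is a minimal sketch of the agreed pass-through approach: ollama-friendly defaults for `base_url` and `api_key`, forwarded unchanged to the OpenAI SDK, which falls back to its `OPENAI_BASE_URL` / `OPENAI_API_KEY` environment variables when a value is explicitly None. The class name and defaults are illustrative assumptions, not the merged implementation.

# Sketch of the "just pass the args through" consensus from the review thread above.
# Names and defaults are illustrative assumptions, not the merged code.
import openai

class PassThroughOpenAIBackend:
    def __init__(
        self,
        base_url: str | None = "http://localhost:11434/v1",  # ollama-friendly default
        api_key: str | None = "ollama",
    ):
        # Forward the values as-is; when a value is explicitly None, the OpenAI SDK
        # falls back to its OPENAI_BASE_URL / OPENAI_API_KEY environment variables.
        self._client = openai.AsyncOpenAI(base_url=base_url, api_key=api_key)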

@@ -598,14 +602,38 @@ async def _generate_from_chat_context_standard(

extra_params: dict[str, Any] = {}
if _format is not None:
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}
if self._server_type == _ServerType.OPENAI:
# The OpenAI platform requires that additionalProperties=False on all response_format schemas.
# However, not all schemas generated by Mellea include additionalProperties.
# GenerativeSlot, in particular, does not add this property.
# The easiest way to address this disparity between OpenAI and other inference providers is to
# monkey-patch the response format exactly when we are actually using the OpenAI server.
#
# This only addresses the additionalProperties=False constraint.
# Other constraints we should be checking/patching are described here:
# https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
monkey_patched_response_schema = _format.model_json_schema()
monkey_patched_response_schema["additionalProperties"] = False
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": monkey_patched_response_schema,
"strict": True,
},
}
else:
FancyLogger().get_logger().warning(
"Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on support JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider."
)
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}

# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
formatted_tools = convert_tools_to_json(tools)
use_tools = len(formatted_tools) > 0

# Build optional reasoning parameters
# NOTE: the OpenAI SDK doesn't like it if you pass the `reasoning_effort` param to a non-reasoning model, e.g. gpt-4o
reasoning_params = {}
if thinking is not None:
reasoning_params["reasoning_effort"] = thinking

chat_response: Coroutine[
Any, Any, ChatCompletion | openai.AsyncStream[ChatCompletionChunk]
] = self._async_client.chat.completions.create(
model=self._hf_model_id,
model=self._model_id,
messages=conversation, # type: ignore
reasoning_effort=thinking, # type: ignore
tools=formatted_tools if use_tools else None, # type: ignore
# parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
**extra_params,
**reasoning_params, # type: ignore
**self._make_backend_specific_and_remove(
model_opts, is_chat_context=ctx.is_chat_context
),
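For reference, a standalone sketch of what the monkey-patched response_format described above produces for a simple Pydantic model; the `Sentiment` model is a hypothetical example, not part of Mellea:

# Standalone sketch of the additionalProperties patch documented above; Sentiment is a made-up example model.
from pydantic import BaseModel

class Sentiment(BaseModel):
    label: str
    confidence: float

schema = Sentiment.model_json_schema()
schema["additionalProperties"] = False  # the OpenAI platform requires this for strict schemas

response_format = {
    "type": "json_schema",
    "json_schema": {"name": Sentiment.__name__, "schema": schema, "strict": True},
}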
@@ -807,7 +841,7 @@ async def generate_from_raw(
try:
completion_response: Completion = (
await self._async_client.completions.create(
model=self._hf_model_id,
model=self._model_id,
prompt=prompts,
extra_body=extra_body,
**self._make_backend_specific_and_remove(
@@ -860,7 +894,10 @@ async def generate_from_raw(
@property
def base_model_name(self):
"""Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`."""
return self._hf_model_id.split("/")[1]
if "/" in self._model_id:
return self._model_id.split("/")[1]
else:
return self._model_id

def add_adapter(self, adapter: OpenAIAdapter):
"""Adds the given adapter to the backend. Must not have been added to a different backend."""
@@ -976,10 +1013,13 @@ def apply_chat_template(self, chat: list[dict[str, str]]):
from transformers import AutoTokenizer

if not hasattr(self, "_tokenizer"):
assert self._base_url, (
"The OpenAI Platform does not support adapters. You must specify a _base_url when using adapters."
)
match _server_type(self._base_url):
case _ServerType.LOCALHOST:
self._tokenizer: "PreTrainedTokenizer" = ( # noqa: UP037
AutoTokenizer.from_pretrained(self._hf_model_id)
AutoTokenizer.from_pretrained(self._model_id)
)
case _ServerType.OPENAI:
raise Exception(
Contributor:
I might be missing something, but I don't see this function being utilized anywhere. I see other functions with the same name, but I don't see an OpenAIBackend.apply_chat_template anywhere.

Contributor (Author):
I don't think this was explicitly being used anywhere. But since the _ServerType code was being set, Fred seems to have touched that part of the code, so I assumed that he might be using it somewhere in his codebase. I might have to run the adapter tests locally to check.

Contributor (Author):
Happy to remove it for now and then figure out the repercussions later.
40 changes: 40 additions & 0 deletions test/backends/test_openai_ollama.py
@@ -216,6 +216,46 @@ async def get_client_async():
assert len(backend._client_cache.cache.values()) == 2


async def test_reasoning_effort_conditional_passing(backend):
"""Test that reasoning_effort is only passed to API when not None."""
from unittest.mock import AsyncMock, MagicMock, patch

ctx = ChatContext()
ctx = ctx.add(CBlock(value="Test"))

mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = "Response"
mock_response.choices[0].message.role = "assistant"

# Test 1: reasoning_effort should NOT be passed when not specified
with patch.object(
backend._async_client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_response
await backend.generate_from_chat_context(
CBlock(value="Hi"), ctx, model_options={}
)
call_kwargs = mock_create.call_args.kwargs
assert "reasoning_effort" not in call_kwargs, (
"reasoning_effort should not be passed when not specified"
)

# Test 2: reasoning_effort SHOULD be passed when specified
with patch.object(
backend._async_client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_response
await backend.generate_from_chat_context(
CBlock(value="Hi"), ctx, model_options={ModelOption.THINKING: "medium"}
)
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs.get("reasoning_effort") == "medium", (
"reasoning_effort should be passed with correct value when specified"
)


if __name__ == "__main__":
import pytest
