Merged
19 commits
6e98053
Adding a fix to pass `reasoning_effort` in conditionally
avinash2692 Dec 24, 2025
ab9e56b
adding tests
avinash2692 Dec 24, 2025
0862cf0
Fixes #274
nrfulton Dec 26, 2025
29481f7
Adds GPT 5.1 model identifier.
nrfulton Dec 26, 2025
1426ee9
Changes OpenAI Backend default model_id to GPT 5.1.
nrfulton Dec 26, 2025
c11fbef
Fixes bug: GenSlots did not work with OpenAI platform.
nrfulton Dec 26, 2025
0bde6ec
Adds inline documentation for OpenAI model options monkey patching.
nrfulton Dec 26, 2025
4d87c83
removes debug print stmt.
nrfulton Dec 26, 2025
f87f86b
adding a comment about reasoning_effort in openai sdk
avinash2692 Jan 5, 2026
e7e161b
Merge branch 'fix/270-openai-reasoning-effort' of https://github.com/…
avinash2692 Jan 5, 2026
b6d16a6
Merge branch 'main' into fix/270-openai-reasoning-effort
avinash2692 Jan 6, 2026
a94205d
removing all instances of hf_model_id in openai backend
avinash2692 Jan 6, 2026
1e7c1b4
removing apply_chat_template and adding assertions for env variable
avinash2692 Jan 6, 2026
a695cb4
adding some tests for param checking
avinash2692 Jan 6, 2026
41a0c62
changing env variable handling logic.
avinash2692 Jan 6, 2026
c905843
base_url check is now a warning
avinash2692 Jan 6, 2026
0a7747a
fix: change warning message in openai.py
jakelorocco Jan 6, 2026
d0ecfc7
marking test as qualitative cause it's causing timeouts in github act…
avinash2692 Jan 6, 2026
17c2862
Merge branch 'fix/270-openai-reasoning-effort' of https://github.com/…
avinash2692 Jan 6, 2026
13 changes: 10 additions & 3 deletions mellea/backends/model_ids.py
@@ -17,6 +17,7 @@ class ModelIdentifier:
ollama_name: str | None = None
watsonx_name: str | None = None
mlx_name: str | None = None
openai_name: str | None = None

hf_tokenizer_name: str | None = None # if None, is the same as hf_model_name

@@ -134,9 +135,9 @@ class ModelIdentifier:

QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")

######################
#### OpenAI models ###
######################
###########################
#### OpenAI open models ###
###########################

OPENAI_GPT_OSS_20B = ModelIdentifier(
hf_model_name="openai/gpt-oss-20b", ollama_name="gpt-oss:20b"
@@ -145,6 +146,12 @@ class ModelIdentifier:
hf_model_name="openai/gpt-oss-120b", ollama_name="gpt-oss:120b"
)

###########################
#### OpenAI prop models ###
###########################

OPENAI_GPT_5_1 = ModelIdentifier(openai_name="gpt-5.1")

#####################
#### Misc models ####
#####################
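Taken together, the model_ids.py changes add an `openai_name` field and a first proprietary-platform identifier. A hedged sketch of how the identifier dataclass and the two kinds of entries fit together (field list reconstructed from the diff above; anything beyond it is an assumption):

from dataclasses import dataclass

@dataclass
class ModelIdentifier:  # sketch; the real class lives in mellea/backends/model_ids.py
    hf_model_name: str | None = None
    ollama_name: str | None = None
    watsonx_name: str | None = None
    mlx_name: str | None = None
    openai_name: str | None = None  # new in this PR
    hf_tokenizer_name: str | None = None  # if None, same as hf_model_name

# Open-weight models keep HF/Ollama names; proprietary ones set only openai_name.
OPENAI_GPT_OSS_20B = ModelIdentifier(
    hf_model_name="openai/gpt-oss-20b", ollama_name="gpt-oss:20b"
)
OPENAI_GPT_5_1 = ModelIdentifier(openai_name="gpt-5.1")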
119 changes: 73 additions & 46 deletions mellea/backends/openai.py
@@ -6,6 +6,7 @@
import functools
import inspect
import json
import os
from collections.abc import Callable, Coroutine
from copy import deepcopy
from enum import Enum
@@ -72,7 +73,7 @@ class OpenAIBackend(FormatterBackend, AdapterMixin):

def __init__(
self,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
model_id: str | ModelIdentifier = model_ids.OPENAI_GPT_5_1,
formatter: Formatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
@@ -142,26 +143,38 @@ def __init__(

self.default_to_constraint_checking_alora = default_to_constraint_checking_alora

self._model_id = model_id
match model_id:
case str():
self._hf_model_id = model_id
self._model_id = model_id
case ModelIdentifier():
assert model_id.hf_model_name is not None, (
"model_id is None. This can also happen if the ModelIdentifier has no hf_model_id name set."
assert model_id.openai_name is not None, (
"model_id is None. This can also happen if the ModelIdentifier has no `openai_name` name set."
)
self._hf_model_id = model_id.hf_model_name
self._model_id = model_id.openai_name

if base_url is None:
self._base_url = "http://localhost:11434/v1" # ollama
else:
self._base_url = base_url
if api_key is None:
self._api_key = "ollama"
else:
self._api_key = api_key
# Use provided parameters or fall back to environment variables
self._api_key = api_key
self._base_url = base_url

self._server_type = _server_type(self._base_url)
# Validate that we have the required configuration
if self._api_key is None and os.getenv("OPENAI_API_KEY") is None:
raise ValueError(
"OPENAI_API_KEY or api_key is required but not set. Please either:\n"
" 1. Set the environment variable: export OPENAI_API_KEY='your-key-here'\n"
" 2. Pass it as a parameter: OpenAIBackend(api_key='your-key-here')"
)

if self._base_url is None and os.getenv("OPENAI_BASE_URL") is None:
FancyLogger.get_logger().warning(
"OPENAI_BASE_URL or base_url is not set.\n"
"The openai SDK is going to assume that the base_url is `https://api.openai.com/v1`"
)

self._server_type: _ServerType = (
_server_type(self._base_url)
if self._base_url is not None
else _ServerType.OPENAI
) # type: ignore

self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)

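The net effect of this hunk: explicit constructor arguments win, then environment variables; a missing API key is a hard error, while a missing base URL only warns. A hedged usage sketch (the import path is assumed from the module location; key values are placeholders):

import os
from mellea.backends.openai import OpenAIBackend  # path assumed from mellea/backends/openai.py

# Explicit parameters take precedence over the environment.
backend = OpenAIBackend(
    model_id="gpt-4",
    api_key="param-api-key",
    base_url="https://api.param.com/v1",
)

# Alternatively rely on the environment; OPENAI_BASE_URL may be omitted,
# in which case the openai SDK defaults to https://api.openai.com/v1 (with a warning).
os.environ["OPENAI_API_KEY"] = "sk-placeholder"
backend = OpenAIBackend()  # default model_id is now OPENAI_GPT_5_1

# With neither api_key nor OPENAI_API_KEY set, __init__ raises ValueError.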
@@ -598,14 +611,38 @@ async def _generate_from_chat_context_standard(

extra_params: dict[str, Any] = {}
if _format is not None:
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}
if self._server_type == _ServerType.OPENAI:
# The OpenAI platform requires that additionalProperties=False on all response_format schemas.
# However, not all schemas generated by Mellea include additionalProperties.
# GenerativeSlot, in particular, does not add this property.
# The easiest way to address this disparity between OpenAI and other inference providers is to
# monkey-patch the response format exactly when we are actually using the OpenAI server.
#
# This only addresses the additionalProperties=False constraint.
# Other constraints we should be checking/patching are described here:
# https://platform.openai.com/docs/guides/structured-outputs?api-mode=chat
monkey_patched_response_schema = _format.model_json_schema()
monkey_patched_response_schema["additionalProperties"] = False
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": monkey_patched_response_schema,
"strict": True,
},
}
else:
FancyLogger.get_logger().warning(
"Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on the JSON schemas passed into `format=`. If you encounter a server-side error following this message, you have found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider."
)
extra_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}

# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
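The monkey-patch comment above is easiest to see with a concrete schema: Pydantic's model_json_schema() typically omits additionalProperties, which the OpenAI platform rejects for strict structured outputs. A minimal sketch of the patched payload (the Answer model is illustrative, not from the PR):

from pydantic import BaseModel

class Answer(BaseModel):  # illustrative stand-in for a GenerativeSlot-style schema
    text: str

schema = Answer.model_json_schema()
# Pydantic (v2, default config) emits no additionalProperties key here...
schema["additionalProperties"] = False  # ...so the backend patches it in for OpenAI.

response_format = {
    "type": "json_schema",
    "json_schema": {"name": Answer.__name__, "schema": schema, "strict": True},
}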
@@ -631,15 +668,21 @@ async def _generate_from_chat_context_standard(
formatted_tools = convert_tools_to_json(tools)
use_tools = len(formatted_tools) > 0

# Build optional reasoning parameters
# NOTE: the openai SDK rejects the `reasoning_effort` param if it is passed to a non-reasoning model, e.g. gpt-4o
reasoning_params = {}
if thinking is not None:
reasoning_params["reasoning_effort"] = thinking

chat_response: Coroutine[
Any, Any, ChatCompletion | openai.AsyncStream[ChatCompletionChunk]
] = self._async_client.chat.completions.create(
model=self._hf_model_id,
model=self._model_id,
messages=conversation, # type: ignore
reasoning_effort=thinking, # type: ignore
tools=formatted_tools if use_tools else None, # type: ignore
# parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
**extra_params,
**reasoning_params, # type: ignore
**self._make_backend_specific_and_remove(
model_opts, is_chat_context=ctx.is_chat_context
),
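From the caller's side, reasoning effort is requested through model options and now reaches the API only when set. A hedged sketch based on the tests below (import paths for CBlock, ChatContext, and ModelOption are assumed):

ctx = ChatContext().add(CBlock(value="Test"))

# No THINKING option -> no reasoning_effort kwarg in chat.completions.create(...).
await backend.generate_from_chat_context(CBlock(value="Hi"), ctx, model_options={})

# THINKING set -> reasoning_effort="medium" is forwarded to the SDK.
await backend.generate_from_chat_context(
    CBlock(value="Hi"), ctx, model_options={ModelOption.THINKING: "medium"}
)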
@@ -807,7 +850,7 @@ async def generate_from_raw(
try:
completion_response: Completion = (
await self._async_client.completions.create(
model=self._hf_model_id,
model=self._model_id,
prompt=prompts,
extra_body=extra_body,
**self._make_backend_specific_and_remove(
@@ -860,7 +903,10 @@ async def generate_from_raw(
@property
def base_model_name(self):
"""Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`."""
return self._hf_model_id.split("/")[1]
if "/" in self._model_id:
return self._model_id.split("/")[1]
else:
return self._model_id

def add_adapter(self, adapter: OpenAIAdapter):
"""Adds the given adapter to the backend. Must not have been added to a different backend."""
@@ -970,22 +1016,3 @@ def list_adapters(self) -> list[str]:
:returns: list of adapter names that are currently registered with this backend
"""
return list(self._loaded_adapters.keys())

def apply_chat_template(self, chat: list[dict[str, str]]):
"""Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
from transformers import AutoTokenizer

if not hasattr(self, "_tokenizer"):
match _server_type(self._base_url):
case _ServerType.LOCALHOST:
self._tokenizer: "PreTrainedTokenizer" = ( # noqa: UP037
AutoTokenizer.from_pretrained(self._hf_model_id)
)
case _ServerType.OPENAI:
raise Exception(
"apply_chat_template is called while targeting a server at openai.com. "
"This is not supported --- openai.com does not support Activated Lora. "
"Use a locally served vllm instance. "
)

return self._tokenizer.apply_chat_template(chat, tokenize=False)
75 changes: 75 additions & 0 deletions test/backends/test_openai_ollama.py
@@ -1,6 +1,7 @@
# test/rits_backend_tests/test_openai_integration.py
import asyncio
import os
from unittest.mock import patch

import openai
import pydantic
@@ -216,6 +217,80 @@ async def get_client_async():
assert len(backend._client_cache.cache.values()) == 2


async def test_reasoning_effort_conditional_passing(backend):
"""Test that reasoning_effort is only passed to API when not None."""
from unittest.mock import AsyncMock, MagicMock, patch

ctx = ChatContext()
ctx = ctx.add(CBlock(value="Test"))

mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = "Response"
mock_response.choices[0].message.role = "assistant"

# Test 1: reasoning_effort should NOT be passed when not specified
with patch.object(
backend._async_client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_response
await backend.generate_from_chat_context(
CBlock(value="Hi"), ctx, model_options={}
)
call_kwargs = mock_create.call_args.kwargs
assert "reasoning_effort" not in call_kwargs, (
"reasoning_effort should not be passed when not specified"
)

# Test 2: reasoning_effort SHOULD be passed when specified
with patch.object(
backend._async_client.chat.completions, "create", new_callable=AsyncMock
) as mock_create:
mock_create.return_value = mock_response
await backend.generate_from_chat_context(
CBlock(value="Hi"), ctx, model_options={ModelOption.THINKING: "medium"}
)
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs.get("reasoning_effort") == "medium", (
"reasoning_effort should be passed with correct value when specified"
)


def test_api_key_and_base_url_from_parameters():
"""Test that API key and base URL can be set via parameters."""
backend = OpenAIBackend(
model_id="gpt-4", api_key="test-api-key", base_url="https://api.test.com/v1"
)
assert backend._api_key == "test-api-key"
assert backend._base_url == "https://api.test.com/v1"


def test_parameter_overrides_env_variable():
"""Test that explicit parameters override environment variables."""
with patch.dict(
os.environ,
{"OPENAI_API_KEY": "env-api-key", "OPENAI_BASE_URL": "https://api.env.com/v1"},
):
backend = OpenAIBackend(
model_id="gpt-4",
api_key="param-api-key",
base_url="https://api.param.com/v1",
)
assert backend._api_key == "param-api-key"
assert backend._base_url == "https://api.param.com/v1"


def test_missing_api_key_raises_error():
"""Test that missing API key raises ValueError with helpful message."""
with patch.dict(os.environ, {}, clear=True):
with pytest.raises(ValueError) as exc_info:
OpenAIBackend(model_id="gpt-4", base_url="https://api.test.com/v1")
assert "OPENAI_API_KEY or api_key is required but not set" in str(
exc_info.value
)


if __name__ == "__main__":
import pytest
