Add support for reasoning_effort parameter for OpenAI o1 models
s-jse committed Jan 8, 2025
1 parent bb1dcac commit a28a3a1
Showing 7 changed files with 49 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -58,6 +58,7 @@ llm_generation_chain(
return_top_logprobs: int = 0,
bind_prompt_values: Optional[dict] = None,
force_skip_cache: bool = False,
reasoning_effort: Optional[str] = None,
) # returns a LangChain chain that accepts inputs and returns a string as output
load_config_from_file(config_file: str)
pprint_chain() # can be used to print inputs or outputs of a LangChain chain.
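
A minimal usage sketch of the new parameter, assuming chainlite's top-level exports (llm_generation_chain, load_config_from_file) and an o1 endpoint configured as in llm_config.yaml further down; anything outside the documented signature is illustrative:

import asyncio

from chainlite import llm_generation_chain, load_config_from_file

load_config_from_file("llm_config.yaml")


async def main() -> None:
    # Build a chain against a reasoning model and request high reasoning effort.
    chain = llm_generation_chain(
        template_file="tests/reasoning.prompt",
        engine="o1",
        max_tokens=2000,
        reasoning_effort="high",  # one of "low", "medium", "high"; omit for non-reasoning models
    )
    print(await chain.ainvoke({}))


asyncio.run(main())
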
1 change: 1 addition & 0 deletions chainlite/chat_lite_llm.py
@@ -56,6 +56,7 @@
from pydantic import BaseModel, Field

from chainlite.llm_config import GlobalVars
from chainlite.llm_output import ToolOutput

logger = logging.getLogger(__name__)

13 changes: 13 additions & 0 deletions chainlite/llm_generate.py
@@ -260,6 +260,7 @@ def llm_generation_chain(
return_top_logprobs: int = 0,
bind_prompt_values: Optional[dict] = None,
force_skip_cache: bool = False,
reasoning_effort: Optional[str] = None,
) -> Runnable:
"""
Constructs a LangChain generation chain that produces LLM responses using the LLM APIs specified in the ChainLite config file.
@@ -286,6 +287,7 @@
return_top_logprobs (int, optional): If > 0, will return the top logprobs for each token, so the output will be Tuple[str, dict]. Defaults to 0.
bind_prompt_values (dict, optional): A dictionary containing {Variable: str : Value}. Binds values to the prompt. Additional variables can be provided when the chain is called. Defaults to {}.
force_skip_cache (bool, optional): If True, will force the LLM to skip the cache, and the new value won't be saved in cache either. Defaults to False.
reasoning_effort (str, optional): The reasoning effort to use for reasoning models such as o1. Must be one of "low", "medium", or "high". Defaults to None, in which case the parameter is not sent and the API's default ("medium") applies. The cache is not sensitive to this parameter, meaning that cached outputs are shared across reasoning effort values.
Returns:
Runnable: The language model generation chain
@@ -294,6 +296,13 @@
IndexError: Raised when no engine matches the provided string in the LLM APIs configured, or the API key is not found.
"""

assert reasoning_effort in [
None,
"low",
"medium",
"high",
], f"Invalid reasoning_effort: {reasoning_effort}. Valid values are 'low', 'medium', 'high'."

if (
sum(
[
@@ -412,6 +421,10 @@ def llm_generation_chain(
model_kwargs["logprobs"] = True
model_kwargs["top_logprobs"] = return_top_logprobs

if reasoning_effort:
# only include it when explicitly set, because most models do not support it
model_kwargs["reasoning_effort"] = reasoning_effort

if tools:
function_json = [
{"type": "function", "function": litellm.utils.function_to_dict(t)}
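
In isolation, the pattern added above is: validate the allowed values up front, then forward reasoning_effort to the underlying call only when it is explicitly set, so models that do not support the parameter are unaffected. A standalone sketch (the helper name build_model_kwargs is illustrative, not from the codebase):

from typing import Optional

VALID_REASONING_EFFORTS = (None, "low", "medium", "high")


def build_model_kwargs(reasoning_effort: Optional[str] = None, **base_kwargs) -> dict:
    # Reject anything outside the values accepted by reasoning models.
    assert reasoning_effort in VALID_REASONING_EFFORTS, (
        f"Invalid reasoning_effort: {reasoning_effort}. "
        "Valid values are 'low', 'medium', 'high'."
    )
    model_kwargs = dict(base_kwargs)
    # Only include the key when explicitly set, because most models do not support it.
    if reasoning_effort:
        model_kwargs["reasoning_effort"] = reasoning_effort
    return model_kwargs


# The key is absent unless requested, so non-reasoning engines see no change.
assert "reasoning_effort" not in build_model_kwargs(temperature=0.0)
assert build_model_kwargs(reasoning_effort="high")["reasoning_effort"] == "high"
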
2 changes: 1 addition & 1 deletion llm_config.yaml
@@ -43,7 +43,7 @@ llm_endpoints:
gpt-4o-openai: gpt-4o-2024-08-06 # you can specify which version of the model you want
gpt-4o: gpt-4o # you can leave it to OpenAI to select the latest model version for you
gpt-4o-another-one: gpt-4o # "model" names, which are on the right side-hand of a mapping, do not need to be unique
o1: o1-preview-2024-09-12
o1: o1

# Example of OpenAI fine-tuned model
- api_base: https://api.openai.com/v1
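
The o1 alias now points at the general o1 model name rather than the pinned o1-preview-2024-09-12 snapshot, presumably because the preview snapshot does not accept the reasoning_effort parameter while the released o1 model does.
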
17 changes: 10 additions & 7 deletions tasks/main.py
@@ -56,15 +56,18 @@ def start_redis(c, redis_port: int = DEFAULT_REDIS_PORT):


@task(pre=[load_api_keys, start_redis], aliases=["test"])
def tests(c, log_level="info", parallel=False):
def tests(c, log_level="info", parallel=False, test_file: str = None):
"""Run tests using pytest"""

test_files = [
"./tests/test_llm_generate.py",
"./tests/test_llm_structured_output.py",
"./tests/test_function_calling.py",
"./tests/test_logprobs.py",
]
if test_file:
test_files = [f"./tests/{test_file}"]
else:
test_files = [
"./tests/test_llm_generate.py",
"./tests/test_llm_structured_output.py",
"./tests/test_function_calling.py",
"./tests/test_logprobs.py",
]

pytest_command = (
f"pytest "
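
Assuming Invoke's standard underscore-to-dash flag conversion, a single file can now presumably be run with, for example, invoke test --test-file test_llm_generate.py (the value is resolved relative to ./tests/), while omitting the flag still runs the default suite.
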
5 changes: 5 additions & 0 deletions tests/reasoning.prompt
@@ -0,0 +1,5 @@
# instruction

# input
Write a Python program that counts the number of substrings in a given string that are palindromes.
A palindrome is a string that reads the same forwards and backwards. The program should take a string as input and output the number of palindromic substrings in the string.
22 changes: 18 additions & 4 deletions tests/test_llm_generate.py
@@ -179,10 +179,6 @@ async def test_cached_batching():
), "The cost should not change after a cached batched LLM call"






@pytest.mark.asyncio(scope="session")
async def test_o1_model():
response = await llm_generation_chain(
@@ -243,3 +239,21 @@ async def async_function(i):
async_function, test_inputs, max_concurrency, desc
)
assert ret == list(test_inputs)


@pytest.mark.asyncio(scope="session")
async def test_o1_reasoning_effort():
for reasoning_effort in ["low", "medium", "high"]:
start_time = time.time()
response = await llm_generation_chain(
template_file="tests/reasoning.prompt",
engine="o1",
max_tokens=2000,
force_skip_cache=True,
reasoning_effort=reasoning_effort,
).ainvoke({})
print(response)
print(
f"Reasoning effort: {reasoning_effort}, Time taken: {time.time() - start_time}"
)
assert response
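
Note that the test sets force_skip_cache=True: since the cache is shared across reasoning_effort values (see the docstring above), skipping it ensures each effort level actually reaches the API, and the per-iteration timing print makes any latency difference between effort levels visible.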
