Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 52 additions & 44 deletions mirascope/llm/_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,62 +240,70 @@ def wrapper(
]:
fn.__mirascope_call__ = True # pyright: ignore [reportFunctionMemberAccess]
if fn_is_async(fn):
# Create a wrapper function that captures the current context when called
@wraps(fn)
def wrapper_with_context(
*args: _P.args, **kwargs: _P.kwargs
) -> Awaitable[

async def _async_call_impl(
current_context: Any, # noqa: ANN401
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get a better type here than Any?

*args: _P.args,
**kwargs: _P.kwargs,
) -> (
CallResponse
| Stream
| _ResponseModelT
| _ParsedOutputT
| (_ResponseModelT | CallResponse)
]:
# Capture the context at call time
current_context = get_current_context()

# Define an async function that uses the captured context
async def context_bound_inner_async() -> (
CallResponse
| Stream
| _ResponseModelT
| _ParsedOutputT
| (_ResponseModelT | CallResponse)
):
# Apply any context overrides to the original call args
effective_call_args = apply_context_overrides_to_call_args(
original_call_args, context_override=current_context
)
):
# Apply any context overrides to the original call args
effective_call_args = apply_context_overrides_to_call_args(
original_call_args, context_override=current_context
)

# Get the appropriate provider call function with the possibly overridden provider
effective_provider = effective_call_args["provider"]
effective_client = effective_call_args["client"]
# Get the appropriate provider call function with the possibly overridden provider
effective_provider = effective_call_args["provider"]
effective_client = effective_call_args["client"]

if effective_provider in get_args(LocalProvider):
provider_call, effective_client = _get_local_provider_call(
cast(LocalProvider, effective_provider),
effective_client,
True,
)
effective_call_args["client"] = effective_client
else:
provider_call = _get_provider_call(
cast(Provider, effective_provider)
)
if effective_provider in get_args(LocalProvider):
provider_call, effective_client = _get_local_provider_call(
cast(LocalProvider, effective_provider),
effective_client,
True,
)
effective_call_args["client"] = effective_client
else:
provider_call = _get_provider_call(
cast(Provider, effective_provider)
)

# Use the provider-specific call function with overridden args
call_kwargs = dict(effective_call_args)
del call_kwargs["provider"] # Not a parameter to provider_call
# Use the provider-specific call function with overridden args
call_kwargs = dict(effective_call_args)
del call_kwargs["provider"] # Not a parameter to provider_call

# Get decorated function using provider_call
decorated = provider_call(**call_kwargs)(fn)
# Get decorated function using provider_call
decorated = provider_call(**call_kwargs)(fn)

# Call the decorated function and wrap the result
result = await decorated(*args, **kwargs)
return _wrap_result(result)
# Call the decorated function and wrap the result
result = await decorated(*args, **kwargs)
return _wrap_result(result)

return context_bound_inner_async()
# Create a sync wrapper that captures context and returns a coroutine.
# Context is captured at call time (not await time) for asyncio.gather support.
@wraps(fn)
def wrapper_with_context(
*args: _P.args, **kwargs: _P.kwargs
) -> Awaitable[
CallResponse
| Stream
| _ResponseModelT
| _ParsedOutputT
| (_ResponseModelT | CallResponse)
]:
# Capture context at call time (when coroutine is created)
current_context = get_current_context()
# Return the coroutine with the captured context
return _async_call_impl(current_context, *args, **kwargs)

# Mark for tenacity compatibility: tenacity checks __call__ for async
# This allows tenacity to detect this as an async callable
wrapper_with_context.__call__ = _async_call_impl # pyright: ignore [reportAttributeAccessIssue]
wrapper_with_context._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue]
wrapper_with_context._original_fn = fn # pyright: ignore [reportAttributeAccessIssue]

Expand Down
135 changes: 135 additions & 0 deletions tests/llm/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,138 @@ async def dummy_async_function():
assert captured_args_list[1]["model"] == "claude-3-5-sonnet", (
"Context model override was not applied when using asyncio.gather"
)


@pytest.mark.asyncio
async def test_async_call_is_coroutine_callable():
"""Test that async decorated functions are properly detected as coroutine callables.

This is important for compatibility with tenacity's @retry decorator,
which uses its is_coroutine_callable helper to determine if a function is async.
Tenacity checks the __call__ attribute for async functions.
"""
import functools
import inspect

def is_coroutine_callable(call):
"""Replicate tenacity's is_coroutine_callable logic."""
if inspect.isclass(call):
return False
if inspect.iscoroutinefunction(call):
return True
partial_call = isinstance(call, functools.partial) and call.func
# This replicates tenacity's exact logic which checks __call__ attribute
dunder_call = partial_call or getattr(call, "__call__", None) # noqa: B004
return inspect.iscoroutinefunction(dunder_call)

def dummy_async_provider_call(
model,
stream,
tools,
response_model,
output_parser,
json_mode,
call_params,
client,
):
def wrapper(fn):
async def inner(*args, **kwargs):
return ConcreteResponse(
metadata=Metadata(),
response={},
tool_types=None,
prompt_template=None,
fn_args={},
dynamic_config={},
messages=[],
call_params=DummyCallParams(),
call_kwargs=BaseCallKwargs(),
user_message_param=None,
start_time=0,
end_time=0,
)

return inner

return wrapper

with patch(
"mirascope.llm._call._get_provider_call",
return_value=dummy_async_provider_call,
):

@call(provider="openai", model="gpt-4o-mini")
async def dummy_async_function(): ...

# The decorated function should be recognized as a coroutine callable
# by tenacity's is_coroutine_callable helper
assert is_coroutine_callable(dummy_async_function), (
"Async decorated function should be detected as a coroutine callable "
"for tenacity @retry compatibility"
)

# Verify it still works correctly
res = await dummy_async_function()
assert isinstance(res, CallResponse)


@pytest.mark.asyncio
async def test_tenacity_retry_with_async_call():
"""Test that tenacity's @retry decorator works with async call functions.

This test verifies the fix for issue #989 where @retry with collect_errors
did not work with async methods.
"""
from tenacity import retry, stop_after_attempt

call_count = 0

def dummy_async_provider_call(
model,
stream,
tools,
response_model,
output_parser,
json_mode,
call_params,
client,
):
def wrapper(fn):
async def inner(*args, **kwargs):
nonlocal call_count
call_count += 1
# Fail the first 2 attempts, succeed on the 3rd
if call_count < 3:
raise ValueError("Simulated failure")
return ConcreteResponse(
metadata=Metadata(),
response={},
tool_types=None,
prompt_template=None,
fn_args={},
dynamic_config={},
messages=[],
call_params=DummyCallParams(),
call_kwargs=BaseCallKwargs(),
user_message_param=None,
start_time=0,
end_time=0,
)

return inner

return wrapper

with patch(
"mirascope.llm._call._get_provider_call",
return_value=dummy_async_provider_call,
):

@retry(stop=stop_after_attempt(3))
@call(provider="openai", model="gpt-4o-mini")
async def retry_async_function(): ...

# This should succeed on the 3rd attempt
res = await retry_async_function()
assert isinstance(res, CallResponse)
assert call_count == 3, "Function should have been called 3 times"