diff --git a/mirascope/llm/_call.py b/mirascope/llm/_call.py index 528aabfefa..a72a8f61c3 100644 --- a/mirascope/llm/_call.py +++ b/mirascope/llm/_call.py @@ -240,62 +240,70 @@ def wrapper( ]: fn.__mirascope_call__ = True # pyright: ignore [reportFunctionMemberAccess] if fn_is_async(fn): - # Create a wrapper function that captures the current context when called - @wraps(fn) - def wrapper_with_context( - *args: _P.args, **kwargs: _P.kwargs - ) -> Awaitable[ + + async def _async_call_impl( + current_context: Any, # noqa: ANN401 + *args: _P.args, + **kwargs: _P.kwargs, + ) -> ( CallResponse | Stream | _ResponseModelT | _ParsedOutputT | (_ResponseModelT | CallResponse) - ]: - # Capture the context at call time - current_context = get_current_context() - - # Define an async function that uses the captured context - async def context_bound_inner_async() -> ( - CallResponse - | Stream - | _ResponseModelT - | _ParsedOutputT - | (_ResponseModelT | CallResponse) - ): - # Apply any context overrides to the original call args - effective_call_args = apply_context_overrides_to_call_args( - original_call_args, context_override=current_context - ) + ): + # Apply any context overrides to the original call args + effective_call_args = apply_context_overrides_to_call_args( + original_call_args, context_override=current_context + ) - # Get the appropriate provider call function with the possibly overridden provider - effective_provider = effective_call_args["provider"] - effective_client = effective_call_args["client"] + # Get the appropriate provider call function with the possibly overridden provider + effective_provider = effective_call_args["provider"] + effective_client = effective_call_args["client"] - if effective_provider in get_args(LocalProvider): - provider_call, effective_client = _get_local_provider_call( - cast(LocalProvider, effective_provider), - effective_client, - True, - ) - effective_call_args["client"] = effective_client - else: - provider_call = _get_provider_call( - cast(Provider, effective_provider) - ) + if effective_provider in get_args(LocalProvider): + provider_call, effective_client = _get_local_provider_call( + cast(LocalProvider, effective_provider), + effective_client, + True, + ) + effective_call_args["client"] = effective_client + else: + provider_call = _get_provider_call( + cast(Provider, effective_provider) + ) - # Use the provider-specific call function with overridden args - call_kwargs = dict(effective_call_args) - del call_kwargs["provider"] # Not a parameter to provider_call + # Use the provider-specific call function with overridden args + call_kwargs = dict(effective_call_args) + del call_kwargs["provider"] # Not a parameter to provider_call - # Get decorated function using provider_call - decorated = provider_call(**call_kwargs)(fn) + # Get decorated function using provider_call + decorated = provider_call(**call_kwargs)(fn) - # Call the decorated function and wrap the result - result = await decorated(*args, **kwargs) - return _wrap_result(result) + # Call the decorated function and wrap the result + result = await decorated(*args, **kwargs) + return _wrap_result(result) - return context_bound_inner_async() + # Create a sync wrapper that captures context and returns a coroutine. + # Context is captured at call time (not await time) for asyncio.gather support. + @wraps(fn) + def wrapper_with_context( + *args: _P.args, **kwargs: _P.kwargs + ) -> Awaitable[ + CallResponse + | Stream + | _ResponseModelT + | _ParsedOutputT + | (_ResponseModelT | CallResponse) + ]: + # Capture context at call time (when coroutine is created) + current_context = get_current_context() + # Return the coroutine with the captured context + return _async_call_impl(current_context, *args, **kwargs) + # Mark for tenacity compatibility: tenacity checks __call__ for async + # This allows tenacity to detect this as an async callable + wrapper_with_context.__call__ = _async_call_impl # pyright: ignore [reportAttributeAccessIssue] wrapper_with_context._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue] wrapper_with_context._original_fn = fn # pyright: ignore [reportAttributeAccessIssue] diff --git a/tests/llm/test_call.py b/tests/llm/test_call.py index 3bc02aaef8..3d043d5ce0 100644 --- a/tests/llm/test_call.py +++ b/tests/llm/test_call.py @@ -583,3 +583,138 @@ async def dummy_async_function(): assert captured_args_list[1]["model"] == "claude-3-5-sonnet", ( "Context model override was not applied when using asyncio.gather" ) + + +@pytest.mark.asyncio +async def test_async_call_is_coroutine_callable(): + """Test that async decorated functions are properly detected as coroutine callables. + + This is important for compatibility with tenacity's @retry decorator, + which uses its is_coroutine_callable helper to determine if a function is async. + Tenacity checks the __call__ attribute for async functions. + """ + import functools + import inspect + + def is_coroutine_callable(call): + """Replicate tenacity's is_coroutine_callable logic.""" + if inspect.isclass(call): + return False + if inspect.iscoroutinefunction(call): + return True + partial_call = isinstance(call, functools.partial) and call.func + # This replicates tenacity's exact logic which checks __call__ attribute + dunder_call = partial_call or getattr(call, "__call__", None) # noqa: B004 + return inspect.iscoroutinefunction(dunder_call) + + def dummy_async_provider_call( + model, + stream, + tools, + response_model, + output_parser, + json_mode, + call_params, + client, + ): + def wrapper(fn): + async def inner(*args, **kwargs): + return ConcreteResponse( + metadata=Metadata(), + response={}, + tool_types=None, + prompt_template=None, + fn_args={}, + dynamic_config={}, + messages=[], + call_params=DummyCallParams(), + call_kwargs=BaseCallKwargs(), + user_message_param=None, + start_time=0, + end_time=0, + ) + + return inner + + return wrapper + + with patch( + "mirascope.llm._call._get_provider_call", + return_value=dummy_async_provider_call, + ): + + @call(provider="openai", model="gpt-4o-mini") + async def dummy_async_function(): ... + + # The decorated function should be recognized as a coroutine callable + # by tenacity's is_coroutine_callable helper + assert is_coroutine_callable(dummy_async_function), ( + "Async decorated function should be detected as a coroutine callable " + "for tenacity @retry compatibility" + ) + + # Verify it still works correctly + res = await dummy_async_function() + assert isinstance(res, CallResponse) + + +@pytest.mark.asyncio +async def test_tenacity_retry_with_async_call(): + """Test that tenacity's @retry decorator works with async call functions. + + This test verifies the fix for issue #989 where @retry with collect_errors + did not work with async methods. + """ + from tenacity import retry, stop_after_attempt + + call_count = 0 + + def dummy_async_provider_call( + model, + stream, + tools, + response_model, + output_parser, + json_mode, + call_params, + client, + ): + def wrapper(fn): + async def inner(*args, **kwargs): + nonlocal call_count + call_count += 1 + # Fail the first 2 attempts, succeed on the 3rd + if call_count < 3: + raise ValueError("Simulated failure") + return ConcreteResponse( + metadata=Metadata(), + response={}, + tool_types=None, + prompt_template=None, + fn_args={}, + dynamic_config={}, + messages=[], + call_params=DummyCallParams(), + call_kwargs=BaseCallKwargs(), + user_message_param=None, + start_time=0, + end_time=0, + ) + + return inner + + return wrapper + + with patch( + "mirascope.llm._call._get_provider_call", + return_value=dummy_async_provider_call, + ): + + @retry(stop=stop_after_attempt(3)) + @call(provider="openai", model="gpt-4o-mini") + async def retry_async_function(): ... + + # This should succeed on the 3rd attempt + res = await retry_async_function() + assert isinstance(res, CallResponse) + assert call_count == 3, "Function should have been called 3 times"