15 | 15 | from llama_stack.core.request_headers import request_provider_data_context |
16 | 16 | from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig |
17 | 17 | from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin |
18 | | -from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam |
| 18 | +from llama_stack_api import ( |
| 19 | + Model, |
| 20 | + ModelType, |
| 21 | + OpenAIChatCompletionRequestWithExtraBody, |
| 22 | + OpenAICompletionRequestWithExtraBody, |
| 23 | + OpenAIUserMessageParam, |
| 24 | +) |
19 | 25 |
20 | 26 |
21 | 27 | class OpenAIMixinImpl(OpenAIMixin): |
@@ -834,3 +840,146 @@ def test_error_message_includes_correct_field_names(self, mixin_with_provider_da |
834 | 840 | error_message = str(exc_info.value) |
835 | 841 | assert "test_api_key" in error_message |
836 | 842 | assert "x-llamastack-provider-data" in error_message |
| 843 | + |
| 844 | + |
| 845 | +class TestOpenAIMixinStreamingMetrics: |
| 846 | + """Test cases for streaming metrics injection in OpenAIMixin""" |
| 847 | + |
| 848 | + async def test_openai_chat_completion_streaming_metrics_injection(self, mixin, mock_client_context): |
| 849 | + """Test that stream_options={"include_usage": True} is injected when streaming and telemetry is enabled""" |
| 850 | + |
| 851 | + params = OpenAIChatCompletionRequestWithExtraBody( |
| 852 | + model="test-model", |
| 853 | + messages=[{"role": "user", "content": "hello"}], |
| 854 | + stream=True, |
| 855 | + stream_options=None, |
| 856 | + ) |
| 857 | + |
| 858 | + mock_client = MagicMock() |
| 859 | + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) |
| 860 | + |
| 861 | + with mock_client_context(mixin, mock_client): |
| 862 | + with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span: |
| 863 | + mock_get_span.return_value = MagicMock() |
| 864 | + |
| 865 | + with patch( |
| 866 | + "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params" |
| 867 | + ) as mock_prepare: |
| 868 | + mock_prepare.return_value = {"model": "test-model"} |
| 869 | + |
| 870 | + await mixin.openai_chat_completion(params) |
| 871 | + |
| 872 | + call_kwargs = mock_prepare.call_args.kwargs |
| 873 | + assert call_kwargs["stream_options"] == {"include_usage": True} |
| 874 | + |
| 875 | + assert params.stream_options is None |
| 876 | + |
| 877 | + async def test_openai_chat_completion_streaming_no_telemetry(self, mixin, mock_client_context): |
| 878 | + """Test that stream_options is NOT injected when telemetry is disabled""" |
| 879 | + |
| 880 | + params = OpenAIChatCompletionRequestWithExtraBody( |
| 881 | + model="test-model", |
| 882 | + messages=[{"role": "user", "content": "hello"}], |
| 883 | + stream=True, |
| 884 | + stream_options=None, |
| 885 | + ) |
| 886 | + |
| 887 | + mock_client = MagicMock() |
| 888 | + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) |
| 889 | + |
| 890 | + with mock_client_context(mixin, mock_client): |
| 891 | + with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span: |
| 892 | + mock_get_span.return_value = None |
| 893 | + |
| 894 | + with patch( |
| 895 | + "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params" |
| 896 | + ) as mock_prepare: |
| 897 | + mock_prepare.return_value = {"model": "test-model"} |
| 898 | + |
| 899 | + await mixin.openai_chat_completion(params) |
| 900 | + |
| 901 | + call_kwargs = mock_prepare.call_args.kwargs |
| 902 | + assert call_kwargs["stream_options"] is None |
| 903 | + |
| 904 | + async def test_openai_completion_streaming_metrics_injection(self, mixin, mock_client_context): |
| 905 | + """Test that stream_options={"include_usage": True} is injected for legacy completion""" |
| 906 | + |
| 907 | + params = OpenAICompletionRequestWithExtraBody( |
| 908 | + model="test-model", |
| 909 | + prompt="hello", |
| 910 | + stream=True, |
| 911 | + stream_options=None, |
| 912 | + ) |
| 913 | + |
| 914 | + mock_client = MagicMock() |
| 915 | + mock_client.completions.create = AsyncMock(return_value=MagicMock()) |
| 916 | + |
| 917 | + with mock_client_context(mixin, mock_client): |
| 918 | + with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span: |
| 919 | + mock_get_span.return_value = MagicMock() |
| 920 | + |
| 921 | + with patch( |
| 922 | + "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params" |
| 923 | + ) as mock_prepare: |
| 924 | + mock_prepare.return_value = {"model": "test-model"} |
| 925 | + |
| 926 | + await mixin.openai_completion(params) |
| 927 | + |
| 928 | + call_kwargs = mock_prepare.call_args.kwargs |
| 929 | + assert call_kwargs["stream_options"] == {"include_usage": True} |
| 930 | + assert params.stream_options is None |
| 931 | + |
| 932 | + async def test_preserves_existing_stream_options(self, mixin, mock_client_context): |
| 933 | + """Test that existing stream_options are preserved and merged""" |
| 934 | + |
| 935 | + params = OpenAIChatCompletionRequestWithExtraBody( |
| 936 | + model="test-model", |
| 937 | + messages=[{"role": "user", "content": "hello"}], |
| 938 | + stream=True, |
| 939 | + stream_options={"include_usage": False}, |
| 940 | + ) |
| 941 | + |
| 942 | + mock_client = MagicMock() |
| 943 | + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) |
| 944 | + |
| 945 | + with mock_client_context(mixin, mock_client): |
| 946 | + with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span: |
| 947 | + mock_get_span.return_value = MagicMock() |
| 948 | + |
| 949 | + with patch( |
| 950 | + "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params" |
| 951 | + ) as mock_prepare: |
| 952 | + mock_prepare.return_value = {"model": "test-model"} |
| 953 | + |
| 954 | + await mixin.openai_chat_completion(params) |
| 955 | + |
| 956 | + call_kwargs = mock_prepare.call_args.kwargs |
| 957 | + # include_usage stays False because the caller explicitly set it |
| 958 | + assert call_kwargs["stream_options"] == {"include_usage": False} |
| 959 | + |
| 960 | + async def test_merges_existing_stream_options(self, mixin, mock_client_context): |
| 961 | + """Test that existing stream_options are merged""" |
| 962 | + |
| 963 | + params = OpenAIChatCompletionRequestWithExtraBody( |
| 964 | + model="test-model", |
| 965 | + messages=[{"role": "user", "content": "hello"}], |
| 966 | + stream=True, |
| 967 | + stream_options={"other_option": True}, |
| 968 | + ) |
| 969 | + |
| 970 | + mock_client = MagicMock() |
| 971 | + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) |
| 972 | + |
| 973 | + with mock_client_context(mixin, mock_client): |
| 974 | + with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span: |
| 975 | + mock_get_span.return_value = MagicMock() |
| 976 | + |
| 977 | + with patch( |
| 978 | + "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params" |
| 979 | + ) as mock_prepare: |
| 980 | + mock_prepare.return_value = {"model": "test-model"} |
| 981 | + |
| 982 | + await mixin.openai_chat_completion(params) |
| 983 | + |
| 984 | + call_kwargs = mock_prepare.call_args.kwargs |
| 985 | + assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True} |
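For reference, a minimal sketch of the injection behaviour these tests pin down. The helper name `_inject_usage_stream_options` is an illustrative assumption, not the actual code path; in `OpenAIMixin` this happens while building the request params inside `openai_chat_completion` / `openai_completion`.

```python
# Minimal sketch (assumption, not the shipped implementation) of the behaviour
# the tests above assert: when a request streams and a telemetry span is active,
# stream_options gains include_usage=True; otherwise it is passed through unchanged.
from llama_stack.core.telemetry import tracing


def _inject_usage_stream_options(stream: bool | None, stream_options: dict | None) -> dict | None:
    """Hypothetical helper; OpenAIMixin performs this inline when preparing params."""
    if not stream or tracing.get_current_span() is None:
        # Not streaming, or telemetry disabled: leave the caller's value untouched.
        return stream_options
    merged = dict(stream_options or {})
    # setdefault preserves an explicit include_usage (e.g. False) set by the caller
    # while still merging the flag into any other existing options.
    merged.setdefault("include_usage", True)
    return merged
```

The copy-then-merge shape matches the assertions that the caller's `params.stream_options` is left unmutated (still `None` after the call) and that an explicit `include_usage` value is never overridden.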