
Commit 4325345

Enable streaming usage metrics for OpenAI providers
Inject stream_options for telemetry, add completion streaming metrics, fix params mutation, remove duplicate provider logic. Add unit tests.
1 parent b6ce242 commit 4325345

5 files changed: +216 −44

src/llama_stack/core/routers/inference.py

Lines changed: 45 additions & 13 deletions
@@ -185,9 +185,12 @@ async def openai_completion(
         params.model = provider_resource_id
 
         if params.stream:
-            return await provider.openai_completion(params)
-            # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
-            # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
+            response_stream = await provider.openai_completion(params)
+            return self.wrap_completion_stream_with_metrics(
+                response=response_stream,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
+            )
 
         response = await provider.openai_completion(params)
         response.model = request_model_id
@@ -412,16 +415,17 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                 completion_text += "".join(choice_data["content_parts"])
 
             # Add metrics to the chunk
-            if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
-                metrics = self._construct_metrics(
-                    prompt_tokens=chunk.usage.prompt_tokens,
-                    completion_tokens=chunk.usage.completion_tokens,
-                    total_tokens=chunk.usage.total_tokens,
-                    fully_qualified_model_id=fully_qualified_model_id,
-                    provider_id=provider_id,
-                )
-                for metric in metrics:
-                    enqueue_event(metric)
+            if self.telemetry_enabled:
+                if hasattr(chunk, "usage") and chunk.usage:
+                    metrics = self._construct_metrics(
+                        prompt_tokens=chunk.usage.prompt_tokens,
+                        completion_tokens=chunk.usage.completion_tokens,
+                        total_tokens=chunk.usage.total_tokens,
+                        fully_qualified_model_id=fully_qualified_model_id,
+                        provider_id=provider_id,
+                    )
+                    for metric in metrics:
+                        enqueue_event(metric)
 
             yield chunk
     finally:
@@ -471,3 +475,31 @@ async def stream_tokens_and_compute_metrics_openai_chat(
            )
            logger.debug(f"InferenceRouter.completion_response: {final_response}")
            asyncio.create_task(self.store.store_chat_completion(final_response, messages))
+
+    async def wrap_completion_stream_with_metrics(
+        self,
+        response: AsyncIterator,
+        fully_qualified_model_id: str,
+        provider_id: str,
+    ) -> AsyncIterator:
+        """Stream OpenAI completion chunks and compute metrics on final chunk."""
+
+        async for chunk in response:
+            if hasattr(chunk, "model"):
+                chunk.model = fully_qualified_model_id
+
+            if getattr(chunk, "choices", None) and any(c.finish_reason for c in chunk.choices):
+                if self.telemetry_enabled:
+                    if getattr(chunk, "usage", None):
+                        usage = chunk.usage
+                        metrics = self._construct_metrics(
+                            prompt_tokens=usage.prompt_tokens,
+                            completion_tokens=usage.completion_tokens,
+                            total_tokens=usage.total_tokens,
+                            fully_qualified_model_id=fully_qualified_model_id,
+                            provider_id=provider_id,
+                        )
+                        for metric in metrics:
+                            enqueue_event(metric)
+
+            yield chunk
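
Since wrap_completion_stream_with_metrics is an async generator, the router returns it without awaiting and the caller simply iterates it: chunks pass through with the fully qualified model id rewritten, and metrics are enqueued when a chunk with a finish_reason also carries usage. A rough consumption sketch, with placeholder ids and an assumed router/provider pair (illustration only, not code from this commit):

from collections.abc import AsyncIterator


async def consume_wrapped_stream(router, provider, params) -> str:
    """Illustrative only: mirrors what InferenceRouter.openai_completion now does when params.stream is True."""
    stream: AsyncIterator = await provider.openai_completion(params)  # raw provider chunks
    wrapped = router.wrap_completion_stream_with_metrics(
        response=stream,
        fully_qualified_model_id="openai/gpt-4o-mini",  # placeholder id
        provider_id="openai",  # placeholder id
    )
    text = ""
    async for chunk in wrapped:  # usage metrics are enqueued as a side effect on the final chunk
        for choice in getattr(chunk, "choices", None) or []:
            text += getattr(choice, "text", "") or ""
    return text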

src/llama_stack/providers/remote/inference/runpod/runpod.py

Lines changed: 0 additions & 19 deletions
@@ -4,14 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from collections.abc import AsyncIterator
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-)
 
 from .config import RunpodImplConfig
 
@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return str(self.config.base_url)
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to add RunPod-specific stream_options requirement."""
-        params = params.model_copy()
-
-        if params.stream and not params.stream_options:
-            params.stream_options = {"include_usage": True}
-
-        return await super().openai_chat_completion(params)

src/llama_stack/providers/remote/inference/watsonx/watsonx.py

Lines changed: 1 addition & 11 deletions
@@ -10,7 +10,6 @@
 import litellm
 import requests
 
-from llama_stack.core.telemetry.tracing import get_current_span
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -56,15 +55,6 @@ async def openai_chat_completion(
         Override parent method to add timeout and inject usage object when missing.
         This works around a LiteLLM defect where usage block is sometimes dropped.
         """
-
-        # Add usage tracking for streaming when telemetry is active
-        stream_options = params.stream_options
-        if params.stream and get_current_span() is not None:
-            if stream_options is None:
-                stream_options = {"include_usage": True}
-            elif "include_usage" not in stream_options:
-                stream_options = {**stream_options, "include_usage": True}
-
         model_obj = await self.model_store.get_model(params.model)
 
         request_params = await prepare_openai_completion_params(
@@ -84,7 +74,7 @@
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=stream_options,
+            stream_options=params.stream_options,
             temperature=params.temperature,
             tool_choice=params.tool_choice,
             tools=params.tools,

src/llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 20 additions & 0 deletions
@@ -258,6 +258,16 @@ async def openai_completion(
         """
         Direct OpenAI completion API call.
         """
+        from llama_stack.core.telemetry.tracing import get_current_span
+
+        # inject if streaming AND telemetry active
+        if params.stream and get_current_span() is not None:
+            params = params.model_copy()
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         completion_kwargs = await prepare_openai_completion_params(
             model=await self._get_provider_model_id(params.model),
@@ -292,6 +302,16 @@ async def openai_chat_completion(
         """
         Direct OpenAI chat completion API call.
         """
+        from llama_stack.core.telemetry.tracing import get_current_span
+
+        # inject if streaming AND telemetry active
+        if params.stream and get_current_span() is not None:
+            params = params.model_copy()
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
         messages = params.messages
 
         if self.download_images:
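
Both methods apply the same three-way merge. Pulled out into a standalone helper purely for illustration (the helper name is hypothetical; in the mixin the logic is inlined and params.model_copy() is called first so the caller's request object is never mutated), the rule behaves like this:

def inject_include_usage(stream: bool, telemetry_active: bool, stream_options: dict | None) -> dict | None:
    """Sketch of the merge rule: only touch stream_options when streaming with telemetry active."""
    if not (stream and telemetry_active):
        return stream_options
    if stream_options is None:
        return {"include_usage": True}
    if "include_usage" not in stream_options:
        return {**stream_options, "include_usage": True}
    return stream_options  # an explicit include_usage value, even False, is left untouched


# Worked cases, matching the unit tests added below:
assert inject_include_usage(True, True, None) == {"include_usage": True}
assert inject_include_usage(True, True, {"other_option": True}) == {"other_option": True, "include_usage": True}
assert inject_include_usage(True, True, {"include_usage": False}) == {"include_usage": False}
assert inject_include_usage(True, False, None) is None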

tests/unit/providers/utils/inference/test_openai_mixin.py

Lines changed: 150 additions & 1 deletion
@@ -15,7 +15,13 @@
 from llama_stack.core.request_headers import request_provider_data_context
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
+from llama_stack_api import (
+    Model,
+    ModelType,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+)
 
 
 class OpenAIMixinImpl(OpenAIMixin):
@@ -834,3 +840,146 @@ def test_error_message_includes_correct_field_names(self, mixin_with_provider_da
         error_message = str(exc_info.value)
         assert "test_api_key" in error_message
         assert "x-llamastack-provider-data" in error_message
+
+
+class TestOpenAIMixinStreamingMetrics:
+    """Test cases for streaming metrics injection in OpenAIMixin"""
+
+    async def test_openai_chat_completion_streaming_metrics_injection(self, mixin, mock_client_context):
+        """Test that stream_options={"include_usage": True} is injected when streaming and telemetry is enabled"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"include_usage": True}
+
+                    assert params.stream_options is None
+
+    async def test_openai_chat_completion_streaming_no_telemetry(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected when telemetry is disabled"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = None
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] is None
+
+    async def test_openai_completion_streaming_metrics_injection(self, mixin, mock_client_context):
+        """Test that stream_options={"include_usage": True} is injected for legacy completion"""
+
+        params = OpenAICompletionRequestWithExtraBody(
+            model="test-model",
+            prompt="hello",
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"include_usage": True}
+                    assert params.stream_options is None
+
+    async def test_preserves_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are preserved and merged"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options={"include_usage": False},
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    # It should stay False because it was present
+                    assert call_kwargs["stream_options"] == {"include_usage": False}
+
+    async def test_merges_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are merged"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options={"other_option": True},
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
