
Commit c81784c

Enable streaming usage metrics for OpenAI providers
Inject stream_options when telemetry is active, add streaming usage metrics for the completion endpoint, stop mutating caller params, remove now-duplicated provider-level logic, and add unit tests.
1 parent acf74cb commit c81784c
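
For context, a minimal sketch (not part of this commit) of the OpenAI streaming behavior the change relies on: when a request sets stream_options={"include_usage": True}, an OpenAI-compatible server sends one extra final chunk whose usage field carries the token counts that the new telemetry path reads. The base URL, API key, and model name below are placeholders.

    from openai import OpenAI

    # Placeholder endpoint and credentials; any OpenAI-compatible server behaves the same way.
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="dummy")

    stream = client.chat.completions.create(
        model="my-model",  # placeholder model id
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        stream_options={"include_usage": True},  # what the mixin now injects when telemetry is active
    )

    for chunk in stream:
        # All chunks carry usage=None except the final one, which holds the token counts.
        if chunk.usage is not None:
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens, chunk.usage.total_tokens)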

File tree (5 files changed, +209 −43 lines):

  src/llama_stack/core/routers/inference.py
  src/llama_stack/providers/remote/inference/runpod/runpod.py
  src/llama_stack/providers/remote/inference/watsonx/watsonx.py
  src/llama_stack/providers/utils/inference/openai_mixin.py
  tests/unit/providers/utils/inference/test_openai_mixin.py

src/llama_stack/core/routers/inference.py

Lines changed: 45 additions & 13 deletions
@@ -185,9 +185,12 @@ async def openai_completion(
         params.model = provider_resource_id

         if params.stream:
-            return await provider.openai_completion(params)
-            # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
-            # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
+            response_stream = await provider.openai_completion(params)
+            return self.wrap_completion_stream_with_metrics(
+                response=response_stream,
+                fully_qualified_model_id=request_model_id,
+                provider_id=provider.__provider_id__,
+            )

         response = await provider.openai_completion(params)
         response.model = request_model_id
@@ -412,16 +415,17 @@ async def stream_tokens_and_compute_metrics_openai_chat(
                         completion_text += "".join(choice_data["content_parts"])

                 # Add metrics to the chunk
-                if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
-                    metrics = self._construct_metrics(
-                        prompt_tokens=chunk.usage.prompt_tokens,
-                        completion_tokens=chunk.usage.completion_tokens,
-                        total_tokens=chunk.usage.total_tokens,
-                        fully_qualified_model_id=fully_qualified_model_id,
-                        provider_id=provider_id,
-                    )
-                    for metric in metrics:
-                        enqueue_event(metric)
+                if self.telemetry_enabled:
+                    if hasattr(chunk, "usage") and chunk.usage:
+                        metrics = self._construct_metrics(
+                            prompt_tokens=chunk.usage.prompt_tokens,
+                            completion_tokens=chunk.usage.completion_tokens,
+                            total_tokens=chunk.usage.total_tokens,
+                            fully_qualified_model_id=fully_qualified_model_id,
+                            provider_id=provider_id,
+                        )
+                        for metric in metrics:
+                            enqueue_event(metric)

                 yield chunk
         finally:
@@ -471,3 +475,31 @@ async def stream_tokens_and_compute_metrics_openai_chat(
             )
             logger.debug(f"InferenceRouter.completion_response: {final_response}")
             asyncio.create_task(self.store.store_chat_completion(final_response, messages))
+
+    async def wrap_completion_stream_with_metrics(
+        self,
+        response: AsyncIterator,
+        fully_qualified_model_id: str,
+        provider_id: str,
+    ) -> AsyncIterator:
+        """Stream OpenAI completion chunks and compute metrics on final chunk."""
+
+        async for chunk in response:
+            if hasattr(chunk, "model"):
+                chunk.model = fully_qualified_model_id
+
+            if getattr(chunk, "choices", None) and any(c.finish_reason for c in chunk.choices):
+                if self.telemetry_enabled:
+                    if getattr(chunk, "usage", None):
+                        usage = chunk.usage
+                        metrics = self._construct_metrics(
+                            prompt_tokens=usage.prompt_tokens,
+                            completion_tokens=usage.completion_tokens,
+                            total_tokens=usage.total_tokens,
+                            fully_qualified_model_id=fully_qualified_model_id,
+                            provider_id=provider_id,
+                        )
+                        for metric in metrics:
+                            enqueue_event(metric)
+
+            yield chunk
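
One detail worth noting in the hunk above: the router returns self.wrap_completion_stream_with_metrics(...) without awaiting it. That works because calling an async-generator function only constructs the generator; its body runs as the caller iterates, so the stream=True path now hands back a wrapping AsyncIterator that the metrics code can intercept, addressing the removed TODO. A tiny self-contained illustration (generic names, not project code):

    import asyncio
    from collections.abc import AsyncIterator


    async def numbers() -> AsyncIterator[int]:
        # An async generator: calling numbers() returns an iterator immediately.
        for i in range(3):
            yield i


    async def main() -> None:
        gen = numbers()      # no await here; nothing has executed yet
        async for n in gen:  # iteration drives the generator body
            print(n)


    asyncio.run(main())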

src/llama_stack/providers/remote/inference/runpod/runpod.py

Lines changed: 0 additions & 19 deletions
@@ -4,14 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import AsyncIterator
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
-from llama_stack_api import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAIChatCompletionRequestWithExtraBody,
-)

 from .config import RunpodImplConfig

@@ -29,15 +22,3 @@ class RunpodInferenceAdapter(OpenAIMixin):
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
         return str(self.config.base_url)
-
-    async def openai_chat_completion(
-        self,
-        params: OpenAIChatCompletionRequestWithExtraBody,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        """Override to add RunPod-specific stream_options requirement."""
-        params = params.model_copy()
-
-        if params.stream and not params.stream_options:
-            params.stream_options = {"include_usage": True}
-
-        return await super().openai_chat_completion(params)

src/llama_stack/providers/remote/inference/watsonx/watsonx.py

Lines changed: 1 addition & 11 deletions
@@ -10,7 +10,6 @@
 import litellm
 import requests

-from llama_stack.core.telemetry.tracing import get_current_span
 from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -56,15 +55,6 @@ async def openai_chat_completion(
         Override parent method to add timeout and inject usage object when missing.
         This works around a LiteLLM defect where usage block is sometimes dropped.
         """
-
-        # Add usage tracking for streaming when telemetry is active
-        stream_options = params.stream_options
-        if params.stream and get_current_span() is not None:
-            if stream_options is None:
-                stream_options = {"include_usage": True}
-            elif "include_usage" not in stream_options:
-                stream_options = {**stream_options, "include_usage": True}
-
         model_obj = await self.model_store.get_model(params.model)

         request_params = await prepare_openai_completion_params(
@@ -84,7 +74,7 @@
             seed=params.seed,
             stop=params.stop,
             stream=params.stream,
-            stream_options=stream_options,
+            stream_options=params.stream_options,
             temperature=params.temperature,
             tool_choice=params.tool_choice,
             tools=params.tools,

src/llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 20 additions & 0 deletions
@@ -271,6 +271,16 @@ async def openai_completion(
         """
         Direct OpenAI completion API call.
         """
+        from llama_stack.core.telemetry.tracing import get_current_span
+
+        # inject if streaming AND telemetry active
+        if params.stream and get_current_span() is not None:
+            params = params.model_copy()
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
         # TODO: fix openai_completion to return type compatible with OpenAI's API response
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
@@ -308,6 +318,16 @@ async def openai_chat_completion(
         """
         Direct OpenAI chat completion API call.
         """
+        from llama_stack.core.telemetry.tracing import get_current_span
+
+        # inject if streaming AND telemetry active
+        if params.stream and get_current_span() is not None:
+            params = params.model_copy()
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)

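
The injection added in both methods above follows the same three-way rule, restated here as a standalone pure function for clarity (the helper name and free-function form are illustrative, not part of the commit): None becomes {"include_usage": True}, an existing dict without the key gets a merged copy, and an explicit caller value, even include_usage=False, is left untouched. The asserts mirror the new unit tests.

    def inject_include_usage(stream_options: dict | None) -> dict:
        """Illustrative restatement of the stream_options injection rule."""
        if stream_options is None:
            return {"include_usage": True}
        if "include_usage" not in stream_options:
            return {**stream_options, "include_usage": True}
        return stream_options  # caller's explicit choice (even False) is respected


    assert inject_include_usage(None) == {"include_usage": True}
    assert inject_include_usage({"other_option": True}) == {"other_option": True, "include_usage": True}
    assert inject_include_usage({"include_usage": False}) == {"include_usage": False}

Because the mixin applies this to params.model_copy(), the caller's request object is never mutated; that is what the assert params.stream_options is None checks in the new tests verify.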

tests/unit/providers/utils/inference/test_openai_mixin.py

Lines changed: 143 additions & 0 deletions
@@ -934,3 +934,146 @@ async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
                 model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
             )
         )
+
+
+class TestOpenAIMixinStreamingMetrics:
+    """Test cases for streaming metrics injection in OpenAIMixin"""
+
+    async def test_openai_chat_completion_streaming_metrics_injection(self, mixin, mock_client_context):
+        """Test that stream_options={"include_usage": True} is injected when streaming and telemetry is enabled"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"include_usage": True}

+        assert params.stream_options is None
+
+    async def test_openai_chat_completion_streaming_no_telemetry(self, mixin, mock_client_context):
+        """Test that stream_options is NOT injected when telemetry is disabled"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = None
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] is None
+
+    async def test_openai_completion_streaming_metrics_injection(self, mixin, mock_client_context):
+        """Test that stream_options={"include_usage": True} is injected for legacy completion"""
+
+        params = OpenAICompletionRequestWithExtraBody(
+            model="test-model",
+            prompt="hello",
+            stream=True,
+            stream_options=None,
+        )
+
+        mock_client = MagicMock()
+        mock_client.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"include_usage": True}
+                    assert params.stream_options is None
+
+    async def test_preserves_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are preserved and merged"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options={"include_usage": False},
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    # It should stay False because it was present
+                    assert call_kwargs["stream_options"] == {"include_usage": False}
+
+    async def test_merges_existing_stream_options(self, mixin, mock_client_context):
+        """Test that existing stream_options are merged"""
+
+        params = OpenAIChatCompletionRequestWithExtraBody(
+            model="test-model",
+            messages=[{"role": "user", "content": "hello"}],
+            stream=True,
+            stream_options={"other_option": True},
+        )
+
+        mock_client = MagicMock()
+        mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
+
+        with mock_client_context(mixin, mock_client):
+            with patch("llama_stack.core.telemetry.tracing.get_current_span") as mock_get_span:
+                mock_get_span.return_value = MagicMock()
+
+                with patch(
+                    "llama_stack.providers.utils.inference.openai_mixin.prepare_openai_completion_params"
+                ) as mock_prepare:
+                    mock_prepare.return_value = {"model": "test-model"}
+
+                    await mixin.openai_chat_completion(params)
+
+                    call_kwargs = mock_prepare.call_args.kwargs
+                    assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True}
