Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,29 @@ class _MessagesStream(ObjectProxy): # type: ignore
"_response_accumulator",
"_with_span",
"_is_finished",
"_manager",
)

def __init__(
    self,
    stream: "Stream[RawMessageStreamEvent]",
    with_span: _WithSpan,
    manager: Any = None,
) -> None:
    """Proxy *stream*, recording events for span finalization.

    Args:
        stream: The raw Anthropic message-event stream being wrapped.
        with_span: Span wrapper used to record the accumulated response.
        manager: Optional owning stream manager; kept so context-manager
            cleanup can be delegated back to it. Defaults to None for
            backward compatibility with callers that pass no manager.
    """
    super().__init__(stream)
    # Independent attribute setup; order is not significant.
    self._with_span = with_span
    self._manager = manager
    self._response_accumulator = _MessageResponseAccumulator()

def __exit__(
    self,
    exc_type: Any,
    exc_val: Any,
    exc_tb: Any,
) -> Any:
    """Delegate context-manager exit to the owning stream manager.

    Propagates the manager's return value: per the context-manager
    protocol (PEP 343), a truthy return from ``__exit__`` suppresses the
    in-flight exception, so silently discarding it (as before) would
    break suppression delegation.

    Args:
        exc_type: Exception class raised inside the ``with`` block, or None.
        exc_val: Exception instance, or None.
        exc_tb: Traceback, or None.

    Returns:
        Whatever the manager's ``__exit__`` returns, or None when no
        manager (or no callable ``__exit__``) is available.
    """
    # getattr+callable instead of hasattr: also guards against a
    # non-callable attribute shadowing __exit__ on the manager.
    manager_exit = getattr(self._manager, "__exit__", None)
    if callable(manager_exit):
        return manager_exit(exc_type, exc_val, exc_tb)
    return None

def __iter__(self) -> Iterator["RawMessageStreamEvent"]:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
if TYPE_CHECKING:
from pydantic import BaseModel

from anthropic.lib.streaming import MessageStreamManager
from anthropic.lib.streaming import MessageStream, MessageStreamManager
from anthropic.types import Message, Usage


Expand Down Expand Up @@ -351,9 +351,9 @@ def __init__(
super().__init__(manager)
self._self_with_span = with_span

def __enter__(self) -> Iterator[str]:
raw = self.__api_request()
return _MessagesStream(raw, self._self_with_span)
def __enter__(self) -> "MessageStream":
    """Enter the wrapped MessageStreamManager and proxy its stream.

    Delegates entry to the underlying manager, then wraps the resulting
    MessageStream so events are accumulated onto the span; the manager
    itself is handed along so exit can be delegated back to it.
    """
    inner_manager = self.__wrapped__
    stream = inner_manager.__enter__()
    return _MessagesStream(stream, self._self_with_span, inner_manager)


def _get_inputs(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
interactions:
- request:
body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "What''s
the capital of France?"}], "model": "claude-3-opus-latest", "stream": true}'
headers: {}
method: POST
uri: https://api.anthropic.com/v1/messages
response:
body:
string: 'event: message_start

data: {"type":"message_start","message":{"id":"msg_01EdTbzEsQHdxkVoFKSAFGUS","type":"message","role":"assistant","model":"claude-3-opus-latest","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":14,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":4}} }


event: content_block_start

data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }


event: ping

data: {"type": "ping"}


event: content_block_delta

data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"The
capital of France"} }


event: content_block_delta

data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"
is Paris."} }


event: content_block_stop

data: {"type":"content_block_stop","index":0 }


event: message_delta

data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":10} }


event: message_stop

data: {"type":"message_stop" }


'
headers: {}
status:
code: 200
message: OK
version: 1
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def test_anthropic_instrumentation_stream_message(
) as stream:
for _ in stream:
pass
# Test that get_final_message() works (fixes #2467)
final_message = stream.get_final_message()
assert final_message is not None
assert hasattr(final_message, 'content')
assert hasattr(final_message, 'usage')

spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == 1
Expand Down Expand Up @@ -232,6 +237,63 @@ def test_anthropic_instrumentation_stream_message(
assert not attributes


@pytest.mark.vcr(
decode_compressed_response=True,
before_record_request=remove_all_vcr_request_headers,
before_record_response=remove_all_vcr_response_headers,
)
def test_anthropic_instrumentation_stream_context_manager_exit(
    tracer_provider: TracerProvider,
    in_memory_span_exporter: InMemorySpanExporter,
    setup_anthropic_instrumentation: Any,
) -> None:
    """Verify the wrapped MessageStreamManager.__exit__ runs on context exit."""
    from unittest.mock import patch

    client = Anthropic(api_key="fake")

    invocation_params = {
        "model": "claude-3-opus-latest",
        "max_tokens": 1024,
    }

    # Record each delegated exit; non-empty means __exit__ was reached.
    exit_calls: list = []

    # Capture the real __exit__ before it is patched away.
    real_exit = anthropic.lib.streaming.MessageStreamManager.__exit__

    def tracking_exit(self: Any, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        exit_calls.append(True)
        # Still run the real cleanup so the stream closes normally.
        return real_exit(self, exc_type, exc_val, exc_tb)

    # Intercept MessageStreamManager.__exit__ to observe the delegation.
    with patch.object(
        anthropic.lib.streaming.MessageStreamManager,
        "__exit__",
        side_effect=tracking_exit,
        autospec=True,
    ):
        with client.messages.stream(
            **invocation_params,
            messages=[
                {"role": "user", "content": "Say hello in one word"},
            ],
        ) as stream:
            # Drain the stream, then confirm the final message is reachable.
            for _ in stream.text_stream:
                pass
            assert stream.get_final_message() is not None

    assert exit_calls, "MessageStreamManager.__exit__() was not called"


@pytest.mark.asyncio
@pytest.mark.vcr(
decode_compressed_response=True,
Expand Down