diff --git a/langfuse/_client/observe.py b/langfuse/_client/observe.py index e8786a0e0..195dd4404 100644 --- a/langfuse/_client/observe.py +++ b/langfuse/_client/observe.py @@ -6,10 +6,8 @@ from functools import wraps from typing import ( Any, - AsyncGenerator, Callable, Dict, - Generator, Iterable, List, Optional, @@ -19,6 +17,7 @@ cast, overload, ) +from collections.abc import AsyncIterator, Iterator from opentelemetry.util._decorator import _AgnosticContextManager from typing_extensions import ParamSpec @@ -278,7 +277,7 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any: as_type=as_type or "span", trace_context=trace_context, input=input, - end_on_exit=False, # when returning a generator, closing on exit would be to early + end_on_exit=False, # when returning an iterator, closing on exit would be too early ) if langfuse_client else None @@ -288,25 +287,25 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any: return await func(*args, **kwargs) with context_manager as langfuse_span_or_generation: - is_return_type_generator = False + is_return_type_iterator = False try: result = await func(*args, **kwargs) if capture_output is True: - if inspect.isgenerator(result): - is_return_type_generator = True + if isinstance(result, Iterator): + is_return_type_iterator = True - return self._wrap_sync_generator_result( + return self._wrap_sync_iterator_result( langfuse_span_or_generation, result, transform_to_string, ) - if inspect.isasyncgen(result): - is_return_type_generator = True + if isinstance(result, AsyncIterator): + is_return_type_iterator = True - return self._wrap_async_generator_result( + return self._wrap_async_iterator_result( langfuse_span_or_generation, result, transform_to_string, @@ -316,14 +315,12 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any: if type(result).__name__ == "StreamingResponse" and hasattr( result, "body_iterator" ): - is_return_type_generator = True - - 
result.body_iterator = ( - self._wrap_async_generator_result( - langfuse_span_or_generation, - result.body_iterator, - transform_to_string, - ) + is_return_type_iterator = True + + result.body_iterator = self._wrap_async_iterator_result( + langfuse_span_or_generation, + result.body_iterator, + transform_to_string, ) langfuse_span_or_generation.update(output=result) @@ -336,7 +333,7 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any: raise e finally: - if not is_return_type_generator: + if not is_return_type_iterator: langfuse_span_or_generation.end() return cast(F, async_wrapper) @@ -396,7 +393,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: as_type=as_type or "span", trace_context=trace_context, input=input, - end_on_exit=False, # when returning a generator, closing on exit would be to early + end_on_exit=False, # when returning an iterator, closing on exit would be too early ) if langfuse_client else None @@ -406,25 +403,25 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) with context_manager as langfuse_span_or_generation: - is_return_type_generator = False + is_return_type_iterator = False try: result = func(*args, **kwargs) if capture_output is True: - if inspect.isgenerator(result): - is_return_type_generator = True + if isinstance(result, Iterator): + is_return_type_iterator = True - return self._wrap_sync_generator_result( + return self._wrap_sync_iterator_result( langfuse_span_or_generation, result, transform_to_string, ) - if inspect.isasyncgen(result): - is_return_type_generator = True + if isinstance(result, AsyncIterator): + is_return_type_iterator = True - return self._wrap_async_generator_result( + return self._wrap_async_iterator_result( langfuse_span_or_generation, result, transform_to_string, @@ -434,14 +431,12 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: if type(result).__name__ == "StreamingResponse" and hasattr( result, "body_iterator" ): - is_return_type_generator 
= True - - result.body_iterator = ( - self._wrap_async_generator_result( - langfuse_span_or_generation, - result.body_iterator, - transform_to_string, - ) + is_return_type_iterator = True + + result.body_iterator = self._wrap_async_iterator_result( + langfuse_span_or_generation, + result.body_iterator, + transform_to_string, ) langfuse_span_or_generation.update(output=result) @@ -454,7 +449,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: raise e finally: - if not is_return_type_generator: + if not is_return_type_iterator: langfuse_span_or_generation.end() return cast(F, sync_wrapper) @@ -481,7 +476,7 @@ def _get_input_from_func_args( "kwargs": func_kwargs, } - def _wrap_sync_generator_result( + def _wrap_sync_iterator_result( self, langfuse_span_or_generation: Union[ LangfuseSpan, @@ -494,19 +489,19 @@ def _wrap_sync_generator_result( LangfuseEmbedding, LangfuseGuardrail, ], - generator: Generator, + iterator: Iterator, transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> Any: preserved_context = contextvars.copy_context() - return _ContextPreservedSyncGeneratorWrapper( - generator, + return _ContextPreservedSyncIteratorWrapper( + iterator, preserved_context, langfuse_span_or_generation, transform_to_string, ) - def _wrap_async_generator_result( + def _wrap_async_iterator_result( self, langfuse_span_or_generation: Union[ LangfuseSpan, @@ -519,13 +514,13 @@ def _wrap_async_generator_result( LangfuseEmbedding, LangfuseGuardrail, ], - generator: AsyncGenerator, + iterator: AsyncIterator, transform_to_string: Optional[Callable[[Iterable], str]] = None, ) -> Any: preserved_context = contextvars.copy_context() - return _ContextPreservedAsyncGeneratorWrapper( - generator, + return _ContextPreservedAsyncIteratorWrapper( + iterator, preserved_context, langfuse_span_or_generation, transform_to_string, @@ -537,12 +532,12 @@ def _wrap_async_generator_result( observe = _decorator.observe -class _ContextPreservedSyncGeneratorWrapper: - """Sync generator 
wrapper that ensures each iteration runs in preserved context.""" +class _ContextPreservedSyncIteratorWrapper: + """Sync iterator wrapper that ensures each iteration runs in preserved context.""" def __init__( self, - generator: Generator, + iterator: Iterator, context: contextvars.Context, span: Union[ LangfuseSpan, @@ -557,25 +552,25 @@ def __init__( ], transform_fn: Optional[Callable[[Iterable], str]], ) -> None: - self.generator = generator + self.iterator = iterator self.context = context self.items: List[Any] = [] self.span = span self.transform_fn = transform_fn - def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper": + def __iter__(self) -> "_ContextPreservedSyncIteratorWrapper": return self def __next__(self) -> Any: try: - # Run the generator's __next__ in the preserved context - item = self.context.run(next, self.generator) + # Run the iterator's __next__ in the preserved context + item = self.context.run(next, self.iterator) self.items.append(item) return item except StopIteration: - # Handle output and span cleanup when generator is exhausted + # Handle output and span cleanup when iterator is exhausted output: Any = self.items if self.transform_fn is not None: @@ -596,12 +591,12 @@ def __next__(self) -> Any: raise -class _ContextPreservedAsyncGeneratorWrapper: - """Async generator wrapper that ensures each iteration runs in preserved context.""" +class _ContextPreservedAsyncIteratorWrapper: + """Async iterator wrapper that ensures each iteration runs in preserved context.""" def __init__( self, - generator: AsyncGenerator, + iterator: AsyncIterator, context: contextvars.Context, span: Union[ LangfuseSpan, @@ -616,34 +611,34 @@ def __init__( ], transform_fn: Optional[Callable[[Iterable], str]], ) -> None: - self.generator = generator + self.iterator = iterator self.context = context self.items: List[Any] = [] self.span = span self.transform_fn = transform_fn - def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper": + def 
__aiter__(self) -> "_ContextPreservedAsyncIteratorWrapper": return self async def __anext__(self) -> Any: try: - # Run the generator's __anext__ in the preserved context + # Run the iterator's __anext__ in the preserved context try: # Python 3.10+ approach with context parameter item = await asyncio.create_task( - self.generator.__anext__(), # type: ignore + self.iterator.__anext__(), # type: ignore context=self.context, ) # type: ignore except TypeError: # Python < 3.10 fallback - context parameter not supported - item = await self.generator.__anext__() + item = await self.iterator.__anext__() self.items.append(item) return item except StopAsyncIteration: - # Handle output and span cleanup when generator is exhausted + # Handle output and span cleanup when iterator is exhausted output: Any = self.items if self.transform_fn is not None: diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 0c82c1a6f..d05b91013 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1751,10 +1751,46 @@ def root_function(): assert generator_obs.output == "item_0item_1item_2" +@pytest.fixture(params=["generator", "iterator"]) +def async_iterable_factory(request): + """Factory that creates either an async generator or async iterator""" + iterable_type = request.param + + if iterable_type == "generator": + + async def create_async_generator(): + for i in range(3): + await asyncio.sleep(0.001) + yield f"async_item_{i}" + + return create_async_generator + else: # iterator + + class AIter: + def __init__(self): + self.index = -1 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index < 2: + await asyncio.sleep(0.001) + self.index += 1 + return f"async_item_{self.index}" + else: + raise StopAsyncIteration + + def create_async_iterator(): + return AIter() + + return create_async_iterator + + @pytest.mark.asyncio @pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11 or higher") -async def 
test_async_generator_context_preservation(): - """Test that async generators preserve context when consumed later (e.g., by streaming responses)""" +async def test_async_generator_context_preservation(async_iterable_factory): + """Test that async generators and iterators preserve context when consumed later (e.g., by streaming responses)""" langfuse = get_client() mock_trace_id = langfuse.create_trace_id() @@ -1762,15 +1798,13 @@ async def test_async_generator_context_preservation(): span_info = {} @observe(name="async_generator") - async def create_async_generator(): + async def create_async_iterable(): current_span = trace.get_current_span() span_info["generator_span_id"] = trace.format_span_id( current_span.get_span_context().span_id ) - for i in range(3): - await asyncio.sleep(0.001) # Simulate async work - yield f"async_item_{i}" + return async_iterable_factory() @observe(name="root") async def root_function(): @@ -1779,15 +1813,15 @@ async def root_function(): current_span.get_span_context().span_id ) - # Return generator without consuming it (like FastAPI StreamingResponse would) - return create_async_generator() + # Return iterable without consuming it (like FastAPI StreamingResponse would) + return await create_async_iterable() - # Simulate the scenario where generator is consumed after root function exits - generator = await root_function(langfuse_trace_id=mock_trace_id) + # Simulate the scenario where iterable is consumed after root function exits + iterable = await root_function(langfuse_trace_id=mock_trace_id) - # Consume generator later (like FastAPI would) + # Consume iterable later (like FastAPI would) items = [] - async for item in generator: + async for item in iterable: items.append(item) langfuse.flush() @@ -1795,7 +1829,7 @@ async def root_function(): # Verify results assert items == ["async_item_0", "async_item_1", "async_item_2"] assert span_info["generator_span_id"] != "0000000000000000", ( - "Generator context should be preserved" + 
"Context should be preserved" ) assert span_info["root_span_id"] != span_info["generator_span_id"], ( "Should have different span IDs" @@ -1810,7 +1844,7 @@ async def root_function(): assert "root" in observation_names assert "async_generator" in observation_names - # Verify generator observation has output + # Verify observation has output generator_obs = next( obs for obs in trace_data.observations if obs.name == "async_generator" )