diff --git a/langfuse/_task_manager/ingestion_consumer.py b/langfuse/_task_manager/ingestion_consumer.py index 9900654c..8648eabf 100644 --- a/langfuse/_task_manager/ingestion_consumer.py +++ b/langfuse/_task_manager/ingestion_consumer.py @@ -16,7 +16,7 @@ from langfuse.parse_error import handle_exception from langfuse.request import APIError, LangfuseClient from langfuse.Sampler import Sampler -from langfuse.serializer import EventSerializer +from langfuse.serializer import BaseEventSerializer, EventSerializer from langfuse.types import MaskFunction from .media_manager import MediaManager @@ -48,6 +48,7 @@ class IngestionConsumer(threading.Thread): _mask: Optional[MaskFunction] _sampler: Sampler _media_manager: MediaManager + _serializer: BaseEventSerializer = EventSerializer def __init__( self, @@ -130,7 +131,7 @@ def _next(self): # check for serialization errors try: - json.dumps(event, cls=EventSerializer) + json.dumps(event, cls=self._serializer) except Exception as e: self._log.error(f"Error serializing item, skipping: {e}") self._ingestion_queue.task_done() @@ -223,7 +224,7 @@ def _truncate_item_in_place( def _get_item_size(self, item: Any) -> int: """Return the size of the item in bytes.""" - return len(json.dumps(item, cls=EventSerializer).encode()) + return len(json.dumps(item, cls=self._serializer).encode()) def _apply_mask_in_place(self, event: dict): """Apply the mask function to the event. This is done in place.""" diff --git a/langfuse/decorators/langfuse_decorator.py b/langfuse/decorators/langfuse_decorator.py index 1116336f..2eb49b76 100644 --- a/langfuse/decorators/langfuse_decorator.py +++ b/langfuse/decorators/langfuse_decorator.py @@ -40,7 +40,6 @@ StatefulTraceClient, StateType, ) -from langfuse.serializer import EventSerializer from langfuse.types import ObservationParams, SpanLevel from langfuse.utils import _get_timestamp from langfuse.utils.error_logging import catch_and_log_errors @@ -182,8 +181,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: ) """ - If the decorator is called without arguments, return the decorator function itself. - This allows the decorator to be used with or without arguments. + If the decorator is called without arguments, return the decorator function itself. + This allows the decorator to be used with or without arguments. Python calls the decorator function with the decorated function as an argument when the decorator is used without arguments. """ if func is None: @@ -401,7 +400,7 @@ def _get_input_from_func_args( # Serialize and deserialize to ensure proper JSON serialization. # Objects are later serialized again so deserialization is necessary here to avoid unnecessary escaping of quotes. - return json.loads(json.dumps(raw_input, cls=EventSerializer)) + return json.loads(json.dumps(raw_input, cls=self.client_instance._serializer)) def _finalize_call( self, @@ -457,7 +456,7 @@ def _handle_call_result( json.loads( json.dumps( result if result is not None and capture_output else None, - cls=EventSerializer, + cls=self.client_instance._serializer, ) ) ) diff --git a/langfuse/request.py b/langfuse/request.py index a66b9076..ec7055e8 100644 --- a/langfuse/request.py +++ b/langfuse/request.py @@ -7,7 +7,7 @@ import httpx -from langfuse.serializer import EventSerializer +from langfuse.serializer import BaseEventSerializer class LangfuseClient: @@ -17,6 +17,7 @@ class LangfuseClient: _version: str _timeout: int _session: httpx.Client + _serializer: BaseEventSerializer def __init__( self, @@ -60,7 +61,7 @@ def post(self, **kwargs) -> httpx.Response: """Post the `kwargs` to the API""" log = logging.getLogger("langfuse") url = self._remove_trailing_slash(self._base_url) + "/api/public/ingestion" - data = json.dumps(kwargs, cls=EventSerializer) + data = json.dumps(kwargs, cls=self._serializer) log.debug("making request: %s to %s", data, url) headers = self.generate_headers() res = self._session.post( diff --git a/langfuse/serializer.py b/langfuse/serializer.py index 0b4dfd1b..27173787 100644 --- a/langfuse/serializer.py +++ b/langfuse/serializer.py @@ -1,7 +1,9 @@ """@private""" +import abc import enum import math +from abc import ABC, abstractmethod from asyncio import Queue from collections.abc import Sequence from dataclasses import asdict, is_dataclass @@ -33,12 +35,41 @@ logger = getLogger(__name__) -class EventSerializer(JSONEncoder): +class BaseEventSerializer(JSONEncoder, ABC): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.seen = set() # Track seen objects to detect circular references - def default(self, obj: Any): + @abstractmethod + def default(self, obj: Any) -> Any: + """Convert object to JSON serializable format""" + pass + + def encode(self, obj: Any) -> str: + self.seen.clear() # Clear seen objects before each encode call + + try: + return super().encode(self.default(obj)) + except Exception: + return f'""' # escaping the string to avoid JSON parsing errors + + @staticmethod + def is_js_safe_integer(value: int) -> bool: + """Ensure the value is within JavaScript's safe range for integers. + + Python's 64-bit integers can exceed this range, necessitating this check. + https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER + """ + max_safe_int = 2**53 - 1 + min_safe_int = -(2**53) + 1 + + return min_safe_int <= value <= max_safe_int + + +class EventSerializer(BaseEventSerializer): + def default( + self, obj: Any + ): # -> str | Any | dict[str, Any] | int | float | list | dict | ...: try: if isinstance(obj, (datetime)): # Timezone-awareness check @@ -158,23 +189,3 @@ def default(self, obj: Any): exc_info=e, ) return f'""' - - def encode(self, obj: Any) -> str: - self.seen.clear() # Clear seen objects before each encode call - - try: - return super().encode(self.default(obj)) - except Exception: - return f'""' # escaping the string to avoid JSON parsing errors - - @staticmethod - def is_js_safe_integer(value: int) -> bool: - """Ensure the value is within JavaScript's safe range for integers. - - Python's 64-bit integers can exceed this range, necessitating this check. - https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER - """ - max_safe_int = 2**53 - 1 - min_safe_int = -(2**53) + 1 - - return min_safe_int <= value <= max_safe_int diff --git a/tests/test_serializer.py b/tests/test_serializer.py index e0156153..76f78151 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,14 +1,19 @@ -from datetime import datetime, date, timezone -from uuid import UUID -from enum import Enum +import json +import threading from dataclasses import dataclass +from datetime import date, datetime, timezone +from enum import Enum from pathlib import Path -from pydantic import BaseModel -import json +from typing import Any +from uuid import UUID + +import pandas as pd import pytest -import threading +from pydantic import BaseModel + import langfuse.serializer from langfuse.serializer import ( + BaseEventSerializer, EventSerializer, ) @@ -189,3 +194,21 @@ def test_numpy_float32(): serializer = EventSerializer() assert serializer.encode(data) == "1.0" + + +def test_custom_serializer(): + class CustomSerializer(BaseEventSerializer): + def default(self, obj: Any) -> Any: + if isinstance(obj, pd.DataFrame): + return obj.to_dict(orient="records") + return super().default(obj) + + df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}) + serializer = CustomSerializer() + result = json.loads(serializer.encode(df)) + + assert result == [ + {"col1": 1, "col2": "a"}, + {"col1": 2, "col2": "b"}, + {"col1": 3, "col2": "c"}, + ]