Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add functionality to pass custom serializer #1116

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions langfuse/_task_manager/ingestion_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from langfuse.parse_error import handle_exception
from langfuse.request import APIError, LangfuseClient
from langfuse.Sampler import Sampler
from langfuse.serializer import EventSerializer
from langfuse.serializer import BaseEventSerializer, EventSerializer
from langfuse.types import MaskFunction

from .media_manager import MediaManager
Expand Down Expand Up @@ -48,6 +48,7 @@ class IngestionConsumer(threading.Thread):
_mask: Optional[MaskFunction]
_sampler: Sampler
_media_manager: MediaManager
_serializer: BaseEventSerializer = EventSerializer

def __init__(
self,
Expand Down Expand Up @@ -130,7 +131,7 @@ def _next(self):

# check for serialization errors
try:
json.dumps(event, cls=EventSerializer)
json.dumps(event, cls=self._serializer)
except Exception as e:
self._log.error(f"Error serializing item, skipping: {e}")
self._ingestion_queue.task_done()
Expand Down Expand Up @@ -223,7 +224,7 @@ def _truncate_item_in_place(

def _get_item_size(self, item: Any) -> int:
"""Return the size of the item in bytes."""
return len(json.dumps(item, cls=EventSerializer).encode())
return len(json.dumps(item, cls=self._serializer).encode())

def _apply_mask_in_place(self, event: dict):
"""Apply the mask function to the event. This is done in place."""
Expand Down
9 changes: 4 additions & 5 deletions langfuse/decorators/langfuse_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
StatefulTraceClient,
StateType,
)
from langfuse.serializer import EventSerializer
from langfuse.types import ObservationParams, SpanLevel
from langfuse.utils import _get_timestamp
from langfuse.utils.error_logging import catch_and_log_errors
Expand Down Expand Up @@ -182,8 +181,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
)

"""
If the decorator is called without arguments, return the decorator function itself.
This allows the decorator to be used with or without arguments.
If the decorator is called without arguments, return the decorator function itself.
This allows the decorator to be used with or without arguments.
Python calls the decorator function with the decorated function as an argument when the decorator is used without arguments.
"""
if func is None:
Expand Down Expand Up @@ -401,7 +400,7 @@ def _get_input_from_func_args(

# Serialize and deserialize to ensure proper JSON serialization.
# Objects are later serialized again so deserialization is necessary here to avoid unnecessary escaping of quotes.
return json.loads(json.dumps(raw_input, cls=EventSerializer))
return json.loads(json.dumps(raw_input, cls=self.client_instance._serializer))

def _finalize_call(
self,
Expand Down Expand Up @@ -457,7 +456,7 @@ def _handle_call_result(
json.loads(
json.dumps(
result if result is not None and capture_output else None,
cls=EventSerializer,
cls=self.client_instance._serializer,
)
)
)
Expand Down
5 changes: 3 additions & 2 deletions langfuse/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import httpx

from langfuse.serializer import EventSerializer
from langfuse.serializer import BaseEventSerializer


class LangfuseClient:
Expand All @@ -17,6 +17,7 @@ class LangfuseClient:
_version: str
_timeout: int
_session: httpx.Client
_serializer: BaseEventSerializer

def __init__(
self,
Expand Down Expand Up @@ -60,7 +61,7 @@ def post(self, **kwargs) -> httpx.Response:
"""Post the `kwargs` to the API"""
log = logging.getLogger("langfuse")
url = self._remove_trailing_slash(self._base_url) + "/api/public/ingestion"
data = json.dumps(kwargs, cls=EventSerializer)
data = json.dumps(kwargs, cls=self._serializer)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure _serializer is initialized (e.g. self._serializer = EventSerializer) in init as it's used in json.dumps but not set.

log.debug("making request: %s to %s", data, url)
headers = self.generate_headers()
res = self._session.post(
Expand Down
55 changes: 33 additions & 22 deletions langfuse/serializer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""@private"""

import abc
import enum
import math
from abc import ABC, abstractmethod
from asyncio import Queue
from collections.abc import Sequence
from dataclasses import asdict, is_dataclass
Expand Down Expand Up @@ -33,12 +35,41 @@
logger = getLogger(__name__)


class EventSerializer(JSONEncoder):
class BaseEventSerializer(JSONEncoder, ABC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.seen = set() # Track seen objects to detect circular references

def default(self, obj: Any):
@abstractmethod
def default(self, obj: Any) -> Any:
"""Convert object to JSON serializable format"""
pass

def encode(self, obj: Any) -> str:
self.seen.clear() # Clear seen objects before each encode call

try:
return super().encode(self.default(obj))
except Exception:
return f'"<not serializable object of type: {type(obj).__name__}>"' # escaping the string to avoid JSON parsing errors

@staticmethod
def is_js_safe_integer(value: int) -> bool:
"""Ensure the value is within JavaScript's safe range for integers.

Python's 64-bit integers can exceed this range, necessitating this check.
https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
"""
max_safe_int = 2**53 - 1
min_safe_int = -(2**53) + 1

return min_safe_int <= value <= max_safe_int


class EventSerializer(BaseEventSerializer):
def default(
self, obj: Any
): # -> str | Any | dict[str, Any] | int | float | list | dict | ...:
try:
if isinstance(obj, (datetime)):
# Timezone-awareness check
Expand Down Expand Up @@ -158,23 +189,3 @@ def default(self, obj: Any):
exc_info=e,
)
return f'"<not serializable object of type: {type(obj).__name__}>"'

def encode(self, obj: Any) -> str:
self.seen.clear() # Clear seen objects before each encode call

try:
return super().encode(self.default(obj))
except Exception:
return f'"<not serializable object of type: {type(obj).__name__}>"' # escaping the string to avoid JSON parsing errors

@staticmethod
def is_js_safe_integer(value: int) -> bool:
"""Ensure the value is within JavaScript's safe range for integers.

Python's 64-bit integers can exceed this range, necessitating this check.
https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER
"""
max_safe_int = 2**53 - 1
min_safe_int = -(2**53) + 1

return min_safe_int <= value <= max_safe_int
35 changes: 29 additions & 6 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from datetime import datetime, date, timezone
from uuid import UUID
from enum import Enum
import json
import threading
from dataclasses import dataclass
from datetime import date, datetime, timezone
from enum import Enum
from pathlib import Path
from pydantic import BaseModel
import json
from typing import Any
from uuid import UUID

import pandas as pd
import pytest
import threading
from pydantic import BaseModel

import langfuse.serializer
from langfuse.serializer import (
BaseEventSerializer,
EventSerializer,
)

Expand Down Expand Up @@ -189,3 +194,21 @@ def test_numpy_float32():
serializer = EventSerializer()

assert serializer.encode(data) == "1.0"


def test_custom_serializer():
class CustomSerializer(BaseEventSerializer):
def default(self, obj: Any) -> Any:
if isinstance(obj, pd.DataFrame):
return obj.to_dict(orient="records")
return super().default(obj)

df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
serializer = CustomSerializer()
result = json.loads(serializer.encode(df))

assert result == [
{"col1": 1, "col2": "a"},
{"col1": 2, "col2": "b"},
{"col1": 3, "col2": "c"},
]