Add support for serializing pd.DataFrame in Arrow IPC formats
judahrand committed Jun 6, 2024
1 parent 8abc7ad commit 1dd00ad
Showing 2 changed files with 80 additions and 3 deletions.
52 changes: 49 additions & 3 deletions src/bentoml/_internal/io_descriptors/pandas.py
@@ -46,6 +46,7 @@
pb_v1alpha1, _ = import_generated_stubs("v1alpha1")
pd = LazyLoader("pd", globals(), "pandas", exc_msg=EXC_MSG)
np = LazyLoader("np", globals(), "numpy")
pyarrow = LazyLoader("pyarrow", globals(), "pyarrow")

logger = logging.getLogger(__name__)

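A note on the pattern: LazyLoader defers the actual pyarrow import until first attribute access, so pyarrow stays an optional dependency. A simplified sketch of the idea (not BentoML's exact implementation):

import importlib
import types

class LazyLoader(types.ModuleType):
    """Placeholder module that imports the real one on first attribute access."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        self._local_name = local_name
        self._parent_globals = parent_globals
        super().__init__(name)

    def __getattr__(self, item: str):
        # First real use: import the module, then replace the placeholder in
        # the caller's globals so later lookups hit the real module directly.
        module = importlib.import_module(self.__name__)
        self._parent_globals[self._local_name] = module
        return getattr(module, item)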
@@ -144,6 +145,8 @@ def _series_openapi_schema(
class SerializationFormat(Enum):
JSON = "application/json"
PARQUET = "application/octet-stream"
ARROW_FILE = "application/vnd.apache.arrow.file"
ARROW_STREAM = "application/vnd.apache.arrow.stream"
CSV = "text/csv"

def __init__(self, mime_type: str):
@@ -156,6 +159,10 @@ def __str__(self) -> str:
return "parquet"
elif self == SerializationFormat.CSV:
return "csv"
elif self == SerializationFormat.ARROW_FILE:
return "arrow_file"
elif self == SerializationFormat.ARROW_STREAM:
return "arrow_stream"
else:
raise ValueError(f"Unknown serialization format: {self}")
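The new members stringify consistently with the existing ones:

assert str(SerializationFormat.ARROW_FILE) == "arrow_file"
assert str(SerializationFormat.ARROW_STREAM) == "arrow_stream"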

Expand All @@ -173,6 +180,10 @@ def _infer_serialization_format_from_request(
return SerializationFormat.PARQUET
elif content_type == "text/csv":
return SerializationFormat.CSV
elif content_type == "application/vnd.apache.arrow.file":
return SerializationFormat.ARROW_FILE
elif content_type == "application/vnd.apache.arrow.stream":
return SerializationFormat.ARROW_STREAM
elif content_type:
logger.debug(
"Unknown Content-Type ('%s'), falling back to '%s' serialization format.",
@@ -196,6 +207,13 @@ def _validate_serialization_format(serialization_format: SerializationFormat):
raise MissingDependencyException(
"Parquet serialization is not available. Try installing pyarrow or fastparquet first."
)
if (
serialization_format is SerializationFormat.ARROW_FILE
or serialization_format is SerializationFormat.ARROW_STREAM
) and find_spec("pyarrow") is None:
raise MissingDependencyException(
"Arrow serialization is not available. Try installing pyarrow first."
)
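find_spec probes for the package without importing it, so the guard is cheap. The same check, standalone (RuntimeError stands in for BentoML's MissingDependencyException):

from importlib.util import find_spec

if find_spec("pyarrow") is None:
    raise RuntimeError("Arrow serialization needs pyarrow: pip install pyarrow")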


class PandasDataFrame(
@@ -311,6 +329,8 @@ def predict(input_df: pd.DataFrame) -> pd.DataFrame:
- :obj:`json` - JSON text format (inferred from content-type ``"application/json"``)
- :obj:`parquet` - Parquet binary format (inferred from content-type ``"application/octet-stream"``)
- :obj:`csv` - CSV text format (inferred from content-type ``"text/csv"``)
- :obj:`arrow_file` - Arrow file format (inferred from content-type ``"application/vnd.apache.arrow.file"``)
- :obj:`arrow_stream` - Arrow stream format (inferred from content-type ``"application/vnd.apache.arrow.stream"``)
Returns:
:obj:`PandasDataFrame`: IO Descriptor that represents a :code:`pd.DataFrame`.
@@ -325,7 +345,13 @@ def __init__(
enforce_dtype: bool = False,
shape: tuple[int, ...] | None = None,
enforce_shape: bool = False,
default_format: t.Literal["json", "parquet", "csv"] = "json",
default_format: t.Literal[
"json",
"parquet",
"csv",
"arrow_file",
"arrow_stream",
] = "json",
):
self._orient: ext.DataFrameOrient = orient
self._columns = columns
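With the widened Literal, an IO descriptor can now default to Arrow output. A minimal sketch (the surrounding service wiring is omitted):

from bentoml.io import PandasDataFrame

# default_format is used when no format can be inferred from the request;
# requests are still matched on Content-Type as shown above.
output_spec = PandasDataFrame(default_format="arrow_stream")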
@@ -371,6 +397,8 @@ def _from_sample(self, sample: ext.PdDataFrame) -> ext.PdDataFrame:
- :obj:`json` - JSON text format (inferred from content-type ``"application/json"``)
- :obj:`parquet` - Parquet binary format (inferred from content-type ``"application/octet-stream"``)
- :obj:`csv` - CSV text format (inferred from content-type ``"text/csv"``)
- :obj:`arrow_file` - Arrow file format (inferred from content-type ``"application/vnd.apache.arrow.file"``)
- :obj:`arrow_stream` - Arrow stream format (inferred from content-type ``"application/vnd.apache.arrow.stream"``)
Returns:
:class:`~bentoml._internal.io_descriptors.pandas.PandasDataFrame`: IODescriptor from given users inputs.
@@ -539,6 +567,12 @@ async def from_http_request(self, request: Request) -> ext.PdDataFrame:
res = pd.read_parquet(io.BytesIO(obj), engine=get_parquet_engine())
elif serialization_format is SerializationFormat.CSV:
res: ext.PdDataFrame = pd.read_csv(io.BytesIO(obj), dtype=dtype)
elif serialization_format is SerializationFormat.ARROW_FILE:
with pyarrow.ipc.open_file(obj) as reader:
res = reader.read_pandas()
elif serialization_format is SerializationFormat.ARROW_STREAM:
with pyarrow.ipc.open_stream(obj) as reader:
res = reader.read_pandas()
else:
raise InvalidArgument(
f"Unknown serialization format ({serialization_format})."
@@ -576,6 +610,18 @@ async def to_http_response(
resp = obj.to_parquet(engine=get_parquet_engine())
elif serialization_format is SerializationFormat.CSV:
resp = obj.to_csv()
elif serialization_format is SerializationFormat.ARROW_FILE:
sink = pyarrow.BufferOutputStream()
batch = self.to_arrow(obj)
with pyarrow.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
resp = sink.getvalue().to_pybytes()
elif serialization_format is SerializationFormat.ARROW_STREAM:
sink = pyarrow.BufferOutputStream()
batch = self.to_arrow(obj)
with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
writer.write_batch(batch)
resp = sink.getvalue().to_pybytes()
else:
raise InvalidArgument(
f"Unknown serialization format ({serialization_format})."
@@ -743,7 +789,7 @@ def from_arrow(self, batch: pyarrow.RecordBatch) -> ext.PdDataFrame:
def to_arrow(self, df: pd.Series[t.Any]) -> pyarrow.RecordBatch:
import pyarrow

return pyarrow.RecordBatch.from_pandas(df)
return pyarrow.RecordBatch.from_pandas(df, preserve_index=True)

def spark_schema(self) -> pyspark.sql.types.StructType:
from pyspark.pandas.typedef import as_spark_type
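preserve_index=True forces the index to be written as a real column (by default a RangeIndex would only be recorded as schema metadata), so it survives the IPC round trip regardless of index type:

import pandas as pd
import pyarrow

df = pd.DataFrame({"col1": [202]}, index=[7])
batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)

# The index travels as an extra field ("__index_level_0__" by default)...
assert "__index_level_0__" in batch.schema.names
# ...and is restored when converting back.
assert batch.to_pandas().index.tolist() == [7]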
Expand Down Expand Up @@ -1201,7 +1247,7 @@ def to_arrow(self, series: pd.Series[t.Any]) -> pyarrow.RecordBatch:
import pyarrow

df = series.to_frame()
return pyarrow.RecordBatch.from_pandas(df)
return pyarrow.RecordBatch.from_pandas(df, preserve_index=True)

def spark_schema(self) -> pyspark.sql.types.StructType:
from pyspark.pandas.typedef import as_spark_type
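The Series path goes through a one-column frame, so the same index handling applies:

import pandas as pd
import pyarrow

series = pd.Series([1.5, 2.5], name="value", index=[10, 11])
batch = pyarrow.RecordBatch.from_pandas(series.to_frame(), preserve_index=True)
assert batch.to_pandas()["value"].equals(series)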
31 changes: 31 additions & 0 deletions tests/e2e/bento_server_http/tests/test_io.py
@@ -7,6 +7,7 @@
from typing import Tuple

import numpy as np
import pyarrow
import pytest

from bentoml.client import AsyncHTTPClient
@@ -144,6 +145,36 @@ async def test_pandas(host: str):
assert response.status_code == 200
assert await response.aread() == b'[{"col1":202}]'

headers = {
"Content-Type": "application/vnd.apache.arrow.stream",
"Origin": ORIGIN,
}
sink = pyarrow.BufferOutputStream()
batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
writer.write_batch(batch)
data = sink.getvalue().to_pybytes()
response = await client.client.post(
"/predict_dataframe", headers=headers, data=data
)
assert response.status_code == 200
assert await response.aread() == b'[{"col1":202}]'

headers = {
"Content-Type": "application/vnd.apache.arrow.file",
"Origin": ORIGIN,
}
sink = pyarrow.BufferOutputStream()
batch = pyarrow.RecordBatch.from_pandas(df, preserve_index=True)
with pyarrow.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
data = sink.getvalue().to_pybytes()
response = await client.client.post(
"/predict_dataframe", headers=headers, data=data
)
assert response.status_code == 200
assert await response.aread() == b'[{"col1":202}]'


@pytest.mark.asyncio
async def test_file(host: str, bin_file: str):
