diff --git a/Cargo.lock b/Cargo.lock index 5c2a93fb75e..5306d542d01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6291,6 +6291,16 @@ dependencies = [ "target-lexicon", ] +[[package]] +name = "pyo3-bytes" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f01f356a8686c821ce6ec5ac74a442ffbda3ce1fb26a113d4120fd78faf9f726" +dependencies = [ + "bytes", + "pyo3", +] + [[package]] name = "pyo3-ffi" version = "0.26.0" @@ -9227,11 +9237,13 @@ dependencies = [ "arrow-array", "arrow-data", "arrow-schema", + "bytes", "itertools 0.14.0", "log", "object_store", "parking_lot", "pyo3", + "pyo3-bytes", "pyo3-log", "tokio", "url", diff --git a/Cargo.toml b/Cargo.toml index 2972a0ae1f1..455e1601a86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -163,6 +163,7 @@ prost = "0.14" prost-build = "0.14" prost-types = "0.14" pyo3 = { version = "0.26.0" } +pyo3-bytes = "0.4" pyo3-log = "0.13.0" rand = "0.9.0" rand_distr = "0.5" diff --git a/vortex-python/Cargo.toml b/vortex-python/Cargo.toml index b79118f69c8..b5719863b1f 100644 --- a/vortex-python/Cargo.toml +++ b/vortex-python/Cargo.toml @@ -25,11 +25,13 @@ crate-type = ["rlib", "cdylib"] arrow-array = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } +bytes = { workspace = true } itertools = { workspace = true } log = { workspace = true } object_store = { workspace = true, features = ["aws", "gcp", "azure", "http"] } parking_lot = { workspace = true } pyo3 = { workspace = true, features = ["abi3", "abi3-py311"] } +pyo3-bytes = { workspace = true } pyo3-log = { workspace = true } tokio = { workspace = true, features = ["fs", "rt-multi-thread"] } url = { workspace = true } diff --git a/vortex-python/benchmark/conftest.py b/vortex-python/benchmark/conftest.py index c6c65fd8ac1..de4a9df067a 100644 --- a/vortex-python/benchmark/conftest.py +++ b/vortex-python/benchmark/conftest.py @@ -4,6 +4,7 @@ import hashlib import math import os +from typing import cast 
import pyarrow as pa import pytest @@ -34,3 +35,9 @@ def vxf(tmpdir_factory: pytest.TempPathFactory, request: pytest.FixtureRequest) a = vx.array(pa.table(columns)) # pyright: ignore[reportCallIssue, reportUnknownArgumentType, reportArgumentType] vx.io.write(a, str(fname)) return vx.open(str(fname)) + + +@pytest.fixture(scope="session", params=[10_000, 2_000_000], ids=["small", "large"]) +def array_fixture(request: pytest.FixtureRequest) -> vx.Array: + size = cast(int, request.param) + return vx.array(pa.table({"x": list(range(size))})) diff --git a/vortex-python/benchmark/test_serialization.py b/vortex-python/benchmark/test_serialization.py new file mode 100644 index 00000000000..60ed46e4350 --- /dev/null +++ b/vortex-python/benchmark/test_serialization.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright the Vortex contributors + +import pickle + +import pytest +from pytest_benchmark.fixture import BenchmarkFixture # pyright: ignore[reportMissingTypeStubs] + +import vortex as vx + + +@pytest.mark.parametrize("protocol", [4, 5], ids=lambda p: f"p{p}") # pyright: ignore[reportAny] +@pytest.mark.parametrize("operation", ["dumps", "loads", "roundtrip"]) +@pytest.mark.benchmark(disable_gc=True) +def test_pickle( + benchmark: BenchmarkFixture, + array_fixture: vx.Array, + protocol: int, + operation: str, +): + benchmark.group = f"pickle_p{protocol}" + + if operation == "dumps": + benchmark(lambda: pickle.dumps(array_fixture, protocol=protocol)) + elif operation == "loads": + pickled_data = pickle.dumps(array_fixture, protocol=protocol) + benchmark(lambda: pickle.loads(pickled_data)) # pyright: ignore[reportAny] + elif operation == "roundtrip": + benchmark(lambda: pickle.loads(pickle.dumps(array_fixture, protocol=protocol))) # pyright: ignore[reportAny] diff --git a/vortex-python/python/vortex/__init__.py b/vortex-python/python/vortex/__init__.py index 56b05201d82..90d60b379c2 100644 --- 
a/vortex-python/python/vortex/__init__.py +++ b/vortex-python/python/vortex/__init__.py @@ -70,7 +70,12 @@ scalar, ) from ._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource] -from .arrays import Array, PyArray, array +from .arrays import ( + Array, + PyArray, + _unpickle_array, # pyright: ignore[reportPrivateUsage] + array, +) from .file import VortexFile, open from .scan import RepeatedScan @@ -155,6 +160,8 @@ # Serde "ArrayContext", "ArrayParts", + # Pickle + "_unpickle_array", # File "VortexFile", "open", diff --git a/vortex-python/python/vortex/_lib/serde.pyi b/vortex-python/python/vortex/_lib/serde.pyi index a3b5316bb1e..626946a7e6c 100644 --- a/vortex-python/python/vortex/_lib/serde.pyi +++ b/vortex-python/python/vortex/_lib/serde.pyi @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright the Vortex contributors +from collections.abc import Sequence from typing import final import pyarrow as pa +from .arrays import Array from .dtype import DType @final @@ -26,3 +28,8 @@ class ArrayParts: @final class ArrayContext: def __len__(self) -> int: ... + +def decode_ipc_array(array_bytes: bytes, dtype_bytes: bytes) -> Array: ... +def decode_ipc_array_buffers( + array_buffers: Sequence[bytes | memoryview], dtype_buffers: Sequence[bytes | memoryview] +) -> Array: ... 
diff --git a/vortex-python/python/vortex/arrays.py b/vortex-python/python/vortex/arrays.py index 831eb612065..a126d94b6a5 100644 --- a/vortex-python/python/vortex/arrays.py +++ b/vortex-python/python/vortex/arrays.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any import pyarrow @@ -11,7 +11,11 @@ import vortex._lib.arrays as _arrays # pyright: ignore[reportMissingModuleSource] from vortex._lib.dtype import DType # pyright: ignore[reportMissingModuleSource] -from vortex._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource] +from vortex._lib.serde import ( # pyright: ignore[reportMissingModuleSource] + ArrayContext, + ArrayParts, + decode_ipc_array_buffers, +) try: import pandas @@ -466,3 +470,15 @@ def decode(cls, parts: ArrayParts, ctx: ArrayContext, dtype: DType, len: int) -> current array. Implementations of this function should validate this information, and then construct a new array. """ + + +def _unpickle_array(array_buffers: Sequence[bytes | memoryview], dtype_buffers: Sequence[bytes | memoryview]) -> Array: # pyright: ignore[reportUnusedFunction] + """Unpickle a Vortex array from IPC-encoded buffer lists. + + This is an internal function used by the pickle module for both protocol 4 and 5. + + For protocol 4, receives list[bytes] from __reduce__. + For protocol 5, receives list[PickleBuffer/memoryview] from __reduce_ex__. + Both use decode_ipc_array_buffers which concatenates the buffers during deserialization. 
+ """ + return decode_ipc_array_buffers(array_buffers, dtype_buffers) diff --git a/vortex-python/src/arrays/mod.rs b/vortex-python/src/arrays/mod.rs index baa9b144750..bd934c8b24d 100644 --- a/vortex-python/src/arrays/mod.rs +++ b/vortex-python/src/arrays/mod.rs @@ -10,14 +10,17 @@ pub(crate) mod py; mod range_to_sequence; use arrow_array::{Array as ArrowArray, ArrayRef as ArrowArrayRef}; +use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyRange, PyRangeMethods}; +use pyo3_bytes::PyBytes; use vortex::arrays::ChunkedVTable; use vortex::arrow::IntoArrowArray; use vortex::compute::{Operator, compare, take}; use vortex::dtype::{DType, Nullability, PType, match_each_integer_ptype}; use vortex::error::VortexError; +use vortex::ipc::messages::{EncoderMessage, MessageEncoder}; use vortex::{Array, ArrayRef, ToCanonical}; use crate::arrays::native::PyNativeArray; @@ -653,4 +656,76 @@ impl PyArray { .map(|buffer| buffer.to_vec()) .collect()) } + + /// Support for Python's pickle protocol. + /// + /// This method serializes the array using Vortex IPC format and returns + /// the data needed for pickle to reconstruct the array. 
+ fn __reduce__<'py>( + slf: &'py Bound<'py, Self>, + ) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { + let py = slf.py(); + let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner(); + + let mut encoder = MessageEncoder::default(); + let buffers = encoder.encode(EncoderMessage::Array(&*array)); + + // Return buffers as a list instead of concatenating + let array_buffers: Vec> = buffers.iter().map(|b| b.to_vec()).collect(); + + let dtype_buffers = encoder.encode(EncoderMessage::DType(array.dtype())); + let dtype_buffers: Vec> = dtype_buffers.iter().map(|b| b.to_vec()).collect(); + + let vortex_module = PyModule::import(py, "vortex")?; + let unpickle_fn = vortex_module.getattr("_unpickle_array")?; + + let args = (array_buffers, dtype_buffers).into_pyobject(py)?; + Ok((unpickle_fn, args.into_any())) + } + + /// Support for Python's pickle protocol for protocol >= 5. + /// + /// Uses PickleBuffer for out-of-band buffer transfer, + /// which potentially avoids copying large data buffers. 
+ fn __reduce_ex__<'py>( + slf: &'py Bound<'py, Self>, + protocol: i32, + ) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { + let py = slf.py(); + + if protocol < 5 { + return Self::__reduce__(slf); + } + + let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner(); + + let mut encoder = MessageEncoder::default(); + let array_buffers = encoder.encode(EncoderMessage::Array(&*array)); + let dtype_buffers = encoder.encode(EncoderMessage::DType(array.dtype())); + + let pickle_module = PyModule::import(py, "pickle")?; + let pickle_buffer_class = pickle_module.getattr("PickleBuffer")?; + + let mut pickle_buffers = Vec::new(); + for buf in array_buffers.into_iter() { + // PyBytes wraps bytes::Bytes and implements the buffer protocol + // This allows PickleBuffer to reference the data without copying + let py_bytes = PyBytes::new(buf).into_py_any(py)?; + let pickle_buffer = pickle_buffer_class.call1((py_bytes,))?; + pickle_buffers.push(pickle_buffer); + } + + let mut dtype_pickle_buffers = Vec::new(); + for buf in dtype_buffers.into_iter() { + let py_bytes = PyBytes::new(buf).into_py_any(py)?; + let pickle_buffer = pickle_buffer_class.call1((py_bytes,))?; + dtype_pickle_buffers.push(pickle_buffer); + } + + let vortex_module = PyModule::import(py, "vortex")?; + let unpickle_fn = vortex_module.getattr("_unpickle_array")?; + + let args = (pickle_buffers, dtype_pickle_buffers).into_pyobject(py)?; + Ok((unpickle_fn, args.into_any())) + } } diff --git a/vortex-python/src/serde/mod.rs b/vortex-python/src/serde/mod.rs index 13ba4c6cf02..ab4462c91d7 100644 --- a/vortex-python/src/serde/mod.rs +++ b/vortex-python/src/serde/mod.rs @@ -4,12 +4,17 @@ pub(crate) mod context; pub(crate) mod parts; +use bytes::Bytes; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::{Bound, Python}; +use vortex::ArraySessionExt; +use vortex::ipc::messages::{DecoderMessage, MessageDecoder, PollRead}; -use crate::install_module; +use crate::arrays::PyArrayRef; use 
crate::serde::context::PyArrayContext; use crate::serde::parts::PyArrayParts; +use crate::{SESSION, install_module}; /// Register serde functions and classes. pub(crate) fn init(py: Python, parent: &Bound) -> PyResult<()> { @@ -19,6 +24,137 @@ pub(crate) fn init(py: Python, parent: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(decode_ipc_array, &m)?)?; + m.add_function(wrap_pyfunction!(decode_ipc_array_buffers, &m)?)?; Ok(()) } + +/// Decode a Vortex array from IPC-encoded bytes. +/// +/// This function decodes both the dtype and array messages from IPC format +/// and returns the reconstructed array. +/// +/// Parameters +/// ---------- +/// array_bytes : bytes +/// The IPC-encoded array message +/// dtype_bytes : bytes +/// The IPC-encoded dtype message +/// +/// Returns +/// ------- +/// Array +/// The decoded Vortex array +#[pyfunction] +fn decode_ipc_array(array_bytes: Vec, dtype_bytes: Vec) -> PyResult { + let registry = SESSION.arrays().registry().clone(); + let mut decoder = MessageDecoder::new(registry); + + let mut dtype_buf = Bytes::from(dtype_bytes); + let dtype = match decoder.read_next(&mut dtype_buf)? { + PollRead::Some(DecoderMessage::DType(dtype)) => dtype, + PollRead::Some(_) => { + return Err(PyValueError::new_err("Expected DType message")); + } + PollRead::NeedMore(_) => { + return Err(PyValueError::new_err("Incomplete DType message")); + } + }; + + let mut array_buf = Bytes::from(array_bytes); + let array = match decoder.read_next(&mut array_buf)? { + PollRead::Some(DecoderMessage::Array((parts, ctx, row_count))) => { + parts.decode(&ctx, &dtype, row_count)? 
+ } + PollRead::Some(_) => { + return Err(PyValueError::new_err("Expected Array message")); + } + PollRead::NeedMore(_) => { + return Err(PyValueError::new_err("Incomplete Array message")); + } + }; + + Ok(PyArrayRef::from(array)) +} + +/// Decode a Vortex array from IPC-encoded buffer protocol objects. +/// +/// This function accepts lists of buffer protocol objects (memoryviews), copies +/// their bytes into one contiguous buffer per message, and decodes the result. +/// +/// Parameters +/// ---------- +/// array_buffers : list of buffer protocol objects +/// List of IPC-encoded array message buffers +/// dtype_buffers : list of buffer protocol objects +/// List of IPC-encoded dtype message buffers +/// +/// Returns +/// ------- +/// Array +/// The decoded Vortex array +#[pyfunction] +fn decode_ipc_array_buffers<'py>( + py: Python<'py>, + array_buffers: Vec>, + dtype_buffers: Vec>, +) -> PyResult { + use pyo3::buffer::PyBuffer; + + let registry = SESSION.arrays().registry().clone(); + let mut decoder = MessageDecoder::new(registry); + + // Concatenate dtype buffers + // Note: PyBuffer returns &[ReadOnlyCell] which requires copying to get &[u8] + let mut dtype_bytes_vec = Vec::new(); + for buf_obj in dtype_buffers { + let buffer = PyBuffer::::get(&buf_obj)?; + let slice = buffer + .as_slice(py) + .ok_or_else(|| PyValueError::new_err("Buffer is not contiguous"))?; + for cell in slice { + dtype_bytes_vec.push(cell.get()); + } + } + let mut dtype_buf = Bytes::from(dtype_bytes_vec); + + // Decode dtype + let dtype = match decoder.read_next(&mut dtype_buf)? 
{ + PollRead::Some(DecoderMessage::DType(dtype)) => dtype, + PollRead::Some(_) => { + return Err(PyValueError::new_err("Expected DType message")); + } + PollRead::NeedMore(_) => { + return Err(PyValueError::new_err("Incomplete DType message")); + } + }; + + // Concatenate array buffers + let mut array_bytes_vec = Vec::new(); + for buf_obj in array_buffers { + let buffer = PyBuffer::::get(&buf_obj)?; + let slice = buffer + .as_slice(py) + .ok_or_else(|| PyValueError::new_err("Buffer is not contiguous"))?; + for cell in slice { + array_bytes_vec.push(cell.get()); + } + } + let mut array_buf = Bytes::from(array_bytes_vec); + + // Decode array + let array = match decoder.read_next(&mut array_buf)? { + PollRead::Some(DecoderMessage::Array((parts, ctx, row_count))) => { + parts.decode(&ctx, &dtype, row_count)? + } + PollRead::Some(_) => { + return Err(PyValueError::new_err("Expected Array message")); + } + PollRead::NeedMore(_) => { + return Err(PyValueError::new_err("Incomplete Array message")); + } + }; + + Ok(PyArrayRef::from(array)) +} diff --git a/vortex-python/test/test_pickle.py b/vortex-python/test/test_pickle.py new file mode 100644 index 00000000000..a5ddf079755 --- /dev/null +++ b/vortex-python/test/test_pickle.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright the Vortex contributors + +import pickle + +import pyarrow as pa + +import vortex as vx + + +def test_pickle_simple_array(): + arr = vx.array([1, 2, 3, 4, 5]) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_array_with_nulls(): + arr = vx.array([1, None, 3, None, 5]) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_string_array(): + arr = vx.array(["hello", "world", "foo", "bar"]) + 
pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_struct_array(): + arr = vx.array( + [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35}, + ] + ) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_table() == arr.to_arrow_table() + + +def test_pickle_chunked_array(): + arr = vx.array(pa.chunked_array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_large_array(): + arr = vx.array(list(range(100_000))) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_empty_array(): + arr = vx.array(pa.array([], type=pa.int64())) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_different_protocols(): + arr = vx.array([1, 2, 3, 4, 5]) + + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + pickled = pickle.dumps(arr, protocol=protocol) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_preserves_dtype(): + arr = vx.array([1, 2, 3, 4, 5]) + original_dtype = arr.dtype + + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert str(restored.dtype) == str(original_dtype) + + +def test_pickle_float_array(): + arr = vx.array([1.5, 2.7, 3.14, 4.0, 5.5]) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: 
ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_binary_array(): + arr = vx.array([b"hello", b"world", b"foo"]) + pickled = pickle.dumps(arr) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_protocol_5_simple(): + arr = vx.array([1, 2, 3, 4, 5]) + pickled = pickle.dumps(arr, protocol=5) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_protocol_5_with_nulls(): + arr = vx.array([1, None, 3, None, 5]) + pickled = pickle.dumps(arr, protocol=5) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_protocol_5_large_array(): + arr = vx.array(list(range(1_000_000))) + pickled = pickle.dumps(arr, protocol=5) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_protocol_5_string_array(): + arr = vx.array(["hello", "world", "protocol", "five"]) + pickled = pickle.dumps(arr, protocol=5) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array() + + +def test_pickle_protocol_5_struct_array(): + arr = vx.array( + [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35}, + ] + ) + pickled = pickle.dumps(arr, protocol=5) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_table() == arr.to_arrow_table() + + +def test_pickle_protocol_comparison(): + arr = vx.array(list(range(10_000))) + + pickled_p4 = pickle.dumps(arr, protocol=4) + pickled_p5 = pickle.dumps(arr, protocol=5) + + restored_p4: vx.Array = pickle.loads(pickled_p4) # pyright: ignore[reportAny] + 
restored_p5: vx.Array = pickle.loads(pickled_p5) # pyright: ignore[reportAny] + + assert restored_p4.to_arrow_array() == arr.to_arrow_array() + assert restored_p5.to_arrow_array() == arr.to_arrow_array() + assert restored_p4.to_arrow_array() == restored_p5.to_arrow_array() + + +def test_pickle_protocol_5_preserves_dtype(): + arr = vx.array([1.5, 2.7, 3.14]) + original_dtype = arr.dtype + + pickled = pickle.dumps(arr, protocol=5) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert str(restored.dtype) == str(original_dtype) + + +def test_pickle_protocol_5_chunked_array(): + arr = vx.array(pa.chunked_array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + pickled = pickle.dumps(arr, protocol=5) + restored: vx.Array = pickle.loads(pickled) # pyright: ignore[reportAny] + + assert restored.to_arrow_array() == arr.to_arrow_array()