Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ prost = "0.14"
prost-build = "0.14"
prost-types = "0.14"
pyo3 = { version = "0.26.0" }
pyo3-bytes = "0.4"
pyo3-log = "0.13.0"
rand = "0.9.0"
rand_distr = "0.5"
Expand Down
2 changes: 2 additions & 0 deletions vortex-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ crate-type = ["rlib", "cdylib"]
arrow-array = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
bytes = { workspace = true }
itertools = { workspace = true }
log = { workspace = true }
object_store = { workspace = true, features = ["aws", "gcp", "azure", "http"] }
parking_lot = { workspace = true }
pyo3 = { workspace = true, features = ["abi3", "abi3-py311"] }
pyo3-bytes = { workspace = true }
pyo3-log = { workspace = true }
tokio = { workspace = true, features = ["fs", "rt-multi-thread"] }
url = { workspace = true }
Expand Down
7 changes: 7 additions & 0 deletions vortex-python/benchmark/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hashlib
import math
import os
from typing import cast

import pyarrow as pa
import pytest
Expand Down Expand Up @@ -34,3 +35,9 @@ def vxf(tmpdir_factory: pytest.TempPathFactory, request: pytest.FixtureRequest)
a = vx.array(pa.table(columns)) # pyright: ignore[reportCallIssue, reportUnknownArgumentType, reportArgumentType]
vx.io.write(a, str(fname))
return vx.open(str(fname))


@pytest.fixture(scope="session", params=[10_000, 2_000_000], ids=["small", "large"])
def array_fixture(request: pytest.FixtureRequest) -> vx.Array:
    """Build a single-column Vortex array of consecutive integers for benchmarks.

    The fixture is session-scoped and parametrized by row count: ``small``
    (10_000 rows) and ``large`` (2_000_000 rows). The column is named ``"x"``
    and holds ``0..size-1``.
    """
    row_count = cast(int, request.param)
    column = list(range(row_count))
    return vx.array(pa.table({"x": column}))
29 changes: 29 additions & 0 deletions vortex-python/benchmark/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright the Vortex contributors

import pickle

import pytest
from pytest_benchmark.fixture import BenchmarkFixture # pyright: ignore[reportMissingTypeStubs]

import vortex as vx


@pytest.mark.parametrize("protocol", [4, 5], ids=lambda p: f"p{p}")  # pyright: ignore[reportAny]
@pytest.mark.parametrize("operation", ["dumps", "loads", "roundtrip"])
@pytest.mark.benchmark(disable_gc=True)
def test_pickle(
    benchmark: BenchmarkFixture,
    array_fixture: vx.Array,
    protocol: int,
    operation: str,
):
    """Benchmark pickling of a Vortex array for one protocol and one operation.

    ``operation`` selects what is timed:

    * ``"dumps"``     -- serialization only.
    * ``"loads"``     -- deserialization only; the payload is produced once,
      outside the timed callable.
    * ``"roundtrip"`` -- serialization followed by deserialization.

    Results are grouped per protocol so that the three operations of the same
    protocol are reported together.
    """
    benchmark.group = f"pickle_p{protocol}"

    if operation == "dumps":
        benchmark(lambda: pickle.dumps(array_fixture, protocol=protocol))
        return

    if operation == "loads":
        # Serialize up front so only ``pickle.loads`` is measured.
        payload = pickle.dumps(array_fixture, protocol=protocol)
        benchmark(lambda: pickle.loads(payload))  # pyright: ignore[reportAny]
        return

    if operation == "roundtrip":
        benchmark(lambda: pickle.loads(pickle.dumps(array_fixture, protocol=protocol)))  # pyright: ignore[reportAny]
9 changes: 8 additions & 1 deletion vortex-python/python/vortex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@
scalar,
)
from ._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource]
from .arrays import Array, PyArray, array
from .arrays import (
Array,
PyArray,
_unpickle_array, # pyright: ignore[reportPrivateUsage]
array,
)
from .file import VortexFile, open
from .scan import RepeatedScan

Expand Down Expand Up @@ -155,6 +160,8 @@
# Serde
"ArrayContext",
"ArrayParts",
# Pickle
"_unpickle_array",
# File
"VortexFile",
"open",
Expand Down
7 changes: 7 additions & 0 deletions vortex-python/python/vortex/_lib/serde.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright the Vortex contributors

from collections.abc import Sequence
from typing import final

import pyarrow as pa

from .arrays import Array
from .dtype import DType

@final
Expand All @@ -26,3 +28,8 @@ class ArrayParts:
# Type stub for the ArrayContext class of the `_lib.serde` module.
# Only `__len__` is declared here; what the length counts is not visible in this stub.
@final
class ArrayContext:
def __len__(self) -> int: ...

# Build an Array from two byte strings: one for the array and one for its dtype.
# NOTE(review): presumably the single-buffer counterpart of `decode_ipc_array_buffers`
# (IPC-encoded input) -- confirm against the native implementation.
def decode_ipc_array(array_bytes: bytes, dtype_bytes: bytes) -> Array: ...
# Build an Array from two sequences of buffers: one sequence for the array and one for its dtype.
# Each buffer may be `bytes` or a `memoryview`.
# Called by `vortex.arrays._unpickle_array` when an array is unpickled.
def decode_ipc_array_buffers(
array_buffers: Sequence[bytes | memoryview], dtype_buffers: Sequence[bytes | memoryview]
) -> Array: ...
20 changes: 18 additions & 2 deletions vortex-python/python/vortex/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from __future__ import annotations

import abc
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any

import pyarrow
from typing_extensions import override

import vortex._lib.arrays as _arrays # pyright: ignore[reportMissingModuleSource]
from vortex._lib.dtype import DType # pyright: ignore[reportMissingModuleSource]
from vortex._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource]
from vortex._lib.serde import ( # pyright: ignore[reportMissingModuleSource]
ArrayContext,
ArrayParts,
decode_ipc_array_buffers,
)

try:
import pandas
Expand Down Expand Up @@ -466,3 +470,15 @@ def decode(cls, parts: ArrayParts, ctx: ArrayContext, dtype: DType, len: int) ->
current array. Implementations of this function should validate this information, and then construct
a new array.
"""


def _unpickle_array(array_buffers: Sequence[bytes | memoryview], dtype_buffers: Sequence[bytes | memoryview]) -> Array:  # pyright: ignore[reportUnusedFunction]
    """Rebuild a Vortex array from the IPC-encoded buffers stored in a pickle.

    Internal reconstructor used by the pickle machinery; it is the callable
    that the array's ``__reduce__`` / ``__reduce_ex__`` hand back to pickle.

    Args:
        array_buffers: Buffers holding the IPC-encoded array. Pickle protocols
            below 5 supply ``bytes`` objects (from ``__reduce__``); protocol 5
            supplies buffers that were wrapped in ``PickleBuffer`` by
            ``__reduce_ex__``.
        dtype_buffers: Buffers holding the IPC-encoded dtype, in the same form.

    Returns:
        The decoded array. All decoding is delegated to
        ``decode_ipc_array_buffers`` regardless of the protocol.
    """
    restored = decode_ipc_array_buffers(array_buffers, dtype_buffers)
    return restored
75 changes: 75 additions & 0 deletions vortex-python/src/arrays/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ pub(crate) mod py;
mod range_to_sequence;

use arrow_array::{Array as ArrowArray, ArrayRef as ArrowArrayRef};
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyRange, PyRangeMethods};
use pyo3_bytes::PyBytes;
use vortex::arrays::ChunkedVTable;
use vortex::arrow::IntoArrowArray;
use vortex::compute::{Operator, compare, take};
use vortex::dtype::{DType, Nullability, PType, match_each_integer_ptype};
use vortex::error::VortexError;
use vortex::ipc::messages::{EncoderMessage, MessageEncoder};
use vortex::{Array, ArrayRef, ToCanonical};

use crate::arrays::native::PyNativeArray;
Expand Down Expand Up @@ -653,4 +656,76 @@ impl PyArray {
.map(|buffer| buffer.to_vec())
.collect())
}

/// Pickle support (`__reduce__`), used for pickle protocols below 5.
///
/// The array and its dtype are each encoded with the Vortex IPC
/// [`MessageEncoder`]. Every encoded buffer is copied into an owned
/// `Vec<u8>`, and the two buffer lists are kept separate instead of being
/// concatenated.
///
/// Returns the `(callable, args)` pair expected by pickle: the callable is
/// `vortex._unpickle_array` and `args` is the tuple
/// `(array_buffers, dtype_buffers)`.
///
/// # Errors
///
/// Fails if `slf` cannot be extracted as an array, if the `vortex` module or
/// its `_unpickle_array` attribute cannot be resolved, or if the arguments
/// cannot be converted into Python objects.
fn __reduce__<'py>(
    slf: &'py Bound<'py, Self>,
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
    let py = slf.py();
    let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner();

    let mut encoder = MessageEncoder::default();

    // Encode the array first, then the dtype, copying each chunk into an owned buffer.
    let array_payload: Vec<Vec<u8>> = encoder
        .encode(EncoderMessage::Array(&*array))
        .iter()
        .map(|chunk| chunk.to_vec())
        .collect();
    let dtype_payload: Vec<Vec<u8>> = encoder
        .encode(EncoderMessage::DType(array.dtype()))
        .iter()
        .map(|chunk| chunk.to_vec())
        .collect();

    // The reconstructor is looked up on the Python side of the package.
    let reconstructor = PyModule::import(py, "vortex")?.getattr("_unpickle_array")?;

    let ctor_args = (array_payload, dtype_payload).into_pyobject(py)?;
    Ok((reconstructor, ctor_args.into_any()))
}

/// Pickle support (`__reduce_ex__`) with a dedicated path for protocol >= 5.
///
/// For protocols below 5 this simply defers to [`Self::__reduce__`].
///
/// For protocol 5 and above, each IPC-encoded buffer of the array and of its
/// dtype is wrapped in a `PyBytes` (which exposes the buffer protocol) and
/// then in a Python `pickle.PickleBuffer`, enabling out-of-band buffer
/// transfer and potentially avoiding copies of large data buffers.
///
/// Returns the `(callable, args)` pair expected by pickle: the callable is
/// `vortex._unpickle_array` and `args` is the tuple
/// `(array_pickle_buffers, dtype_pickle_buffers)`.
///
/// # Errors
///
/// Fails if `slf` cannot be extracted as an array, if `pickle.PickleBuffer`
/// or `vortex._unpickle_array` cannot be resolved, or if constructing any
/// `PickleBuffer` raises.
fn __reduce_ex__<'py>(
    slf: &'py Bound<'py, Self>,
    protocol: i32,
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
    // Older protocols have no out-of-band buffers; use the copying path.
    if protocol < 5 {
        return Self::__reduce__(slf);
    }

    let py = slf.py();
    let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner();

    let mut encoder = MessageEncoder::default();
    let array_buffers = encoder.encode(EncoderMessage::Array(&*array));
    let dtype_buffers = encoder.encode(EncoderMessage::DType(array.dtype()));

    let pickle_buffer_class = PyModule::import(py, "pickle")?.getattr("PickleBuffer")?;

    // PyBytes wraps the encoded buffer and implements the buffer protocol, so
    // PickleBuffer can reference the data without copying it.
    let array_views = array_buffers
        .into_iter()
        .map(|buf| pickle_buffer_class.call1((PyBytes::new(buf).into_py_any(py)?,)))
        .collect::<PyResult<Vec<_>>>()?;

    // Same wrapping for the dtype buffers, performed only after the array
    // buffers succeeded.
    let dtype_views = dtype_buffers
        .into_iter()
        .map(|buf| pickle_buffer_class.call1((PyBytes::new(buf).into_py_any(py)?,)))
        .collect::<PyResult<Vec<_>>>()?;

    let reconstructor = PyModule::import(py, "vortex")?.getattr("_unpickle_array")?;

    let ctor_args = (array_views, dtype_views).into_pyobject(py)?;
    Ok((reconstructor, ctor_args.into_any()))
}
}
Loading
Loading