Skip to content

Commit b7a3605

Browse files
authored
Improvements and tests (#57)
* Add compression enum * Check buffer emptiness * Remove unused functions * Reorganize code * Refactor and add some tests * Add more tests * Add condition to calculate_body_size test
1 parent a5d9846 commit b7a3605

24 files changed

+506
-375
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ set(SPARROW_IPC_SRC
133133
${SPARROW_IPC_SOURCE_DIR}/arrow_interface/arrow_schema/private_data.cpp
134134
${SPARROW_IPC_SOURCE_DIR}/chunk_memory_serializer.cpp
135135
${SPARROW_IPC_SOURCE_DIR}/compression.cpp
136+
${SPARROW_IPC_SOURCE_DIR}/compression_impl.hpp
136137
${SPARROW_IPC_SOURCE_DIR}/deserialize_fixedsizebinary_array.cpp
137138
${SPARROW_IPC_SOURCE_DIR}/deserialize_utils.cpp
138139
${SPARROW_IPC_SOURCE_DIR}/deserialize.cpp

include/sparrow_ipc/chunk_memory_serializer.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88

99
#include <sparrow/record_batch.hpp>
1010

11-
#include "Message_generated.h"
12-
1311
#include "sparrow_ipc/any_output_stream.hpp"
1412
#include "sparrow_ipc/chunk_memory_output_stream.hpp"
13+
#include "sparrow_ipc/compression.hpp"
1514
#include "sparrow_ipc/config/config.hpp"
1615
#include "sparrow_ipc/memory_output_stream.hpp"
1716
#include "sparrow_ipc/serialize.hpp"
@@ -44,8 +43,7 @@ namespace sparrow_ipc
4443
* @param stream Reference to a chunked memory output stream that will receive the serialized chunks
4544
* @param compression Optional: The compression type to use for record batch bodies.
4645
*/
47-
// TODO Use enums and such to avoid including flatbuffers headers
48-
chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
46+
chunk_serializer(chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>& stream, std::optional<CompressionType> compression = std::nullopt);
4947

5048
/**
5149
* @brief Writes a single record batch to the chunked stream.
@@ -131,7 +129,7 @@ namespace sparrow_ipc
131129
std::vector<sparrow::data_type> m_dtypes;
132130
chunked_memory_output_stream<std::vector<std::vector<uint8_t>>>* m_pstream;
133131
bool m_ended{false};
134-
std::optional<org::apache::arrow::flatbuf::CompressionType> m_compression;
132+
std::optional<CompressionType> m_compression;
135133
};
136134

137135
// Implementation

include/sparrow_ipc/compression.hpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,21 @@
55
#include <variant>
66
#include <vector>
77

8-
#include "Message_generated.h"
9-
108
#include "sparrow_ipc/config/config.hpp"
119

1210
namespace sparrow_ipc
1311
{
14-
// TODO use these later if needed for wrapping purposes (flatbuffers/lz4)
15-
// enum class CompressionType
16-
// {
17-
// NONE,
18-
// LZ4,
19-
// ZSTD
20-
// };
21-
22-
// CompressionType to_compression_type(org::apache::arrow::flatbuf::CompressionType compression_type);
12+
enum class CompressionType
13+
{
14+
LZ4_FRAME,
15+
ZSTD
16+
};
2317

24-
constexpr auto CompressionHeaderSize = sizeof(std::int64_t);
18+
[[nodiscard]] SPARROW_IPC_API std::vector<std::uint8_t> compress(
19+
const CompressionType compression_type,
20+
std::span<const std::uint8_t> data);
2521

26-
[[nodiscard]] SPARROW_IPC_API std::vector<std::uint8_t> compress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
27-
[[nodiscard]] SPARROW_IPC_API std::variant<std::vector<std::uint8_t>, std::span<const std::uint8_t>> decompress(const org::apache::arrow::flatbuf::CompressionType compression_type, std::span<const std::uint8_t> data);
22+
[[nodiscard]] SPARROW_IPC_API std::variant<std::vector<std::uint8_t>, std::span<const std::uint8_t>> decompress(
23+
const CompressionType compression_type,
24+
std::span<const std::uint8_t> data);
2825
}

include/sparrow_ipc/deserialize_primitive_array.hpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,7 @@ namespace sparrow_ipc
4343

4444
if (compression)
4545
{
46-
// TODO Handle buffers emptiness thoroughly / which is and which is not allowed...
47-
// Validity buffers can be empty
48-
if (validity_buffer_span.empty())
49-
{
50-
buffers.push_back(validity_buffer_span);
51-
}
52-
else
53-
{
54-
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
55-
}
46+
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
5647
buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
5748
}
5849
else

include/sparrow_ipc/deserialize_utils.hpp

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,22 @@ namespace sparrow_ipc::utils
3232
);
3333

3434
/**
35-
* @brief Extracts bitmap pointer and null count from a RecordBatch buffer.
36-
*
37-
* This function retrieves a bitmap buffer from the specified index in the RecordBatch's
38-
* buffer list and calculates the number of null values represented by the bitmap.
35+
* @brief Extracts a buffer from a RecordBatch's body.
3936
*
40-
* @param record_batch The Arrow RecordBatch containing buffer metadata
41-
* @param body The raw buffer data as a byte span
42-
* @param index The index of the bitmap buffer in the RecordBatch's buffer list
37+
* This function retrieves a buffer span from the specified index in the RecordBatch's
38+
* buffer list and increments the index.
4339
*
44-
* @return A pair containing:
45-
* - First: Pointer to the bitmap data (nullptr if buffer is empty)
46-
* - Second: Count of null values in the bitmap (0 if buffer is empty)
40+
* @param record_batch The Arrow RecordBatch containing buffer metadata.
41+
* @param body The raw buffer data as a byte span.
42+
* @param buffer_index The index of the buffer to retrieve. This value is incremented by the function.
4743
*
48-
* @note If the bitmap buffer has zero length, returns {nullptr, 0}
49-
* @note The returned pointer is a non-const cast of the original const data
44+
* @return A `std::span<const uint8_t>` viewing the extracted buffer data.
45+
* @throws std::runtime_error if the buffer metadata indicates a buffer that exceeds the body size.
5046
*/
51-
// TODO to be removed when not used anymore (after adding compression to deserialize_fixedsizebinary_array)
52-
[[nodiscard]] std::pair<std::uint8_t*, int64_t> get_bitmap_pointer_and_null_count(
47+
[[nodiscard]] std::span<const uint8_t> get_buffer(
5348
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
5449
std::span<const uint8_t> body,
55-
size_t index
50+
size_t& buffer_index
5651
);
5752

5853
/**
@@ -72,23 +67,4 @@ namespace sparrow_ipc::utils
7267
std::span<const uint8_t> buffer_span,
7368
const org::apache::arrow::flatbuf::BodyCompression* compression
7469
);
75-
76-
/**
77-
* @brief Extracts a buffer from a RecordBatch's body.
78-
*
79-
* This function retrieves a buffer span from the specified index in the RecordBatch's
80-
* buffer list and increments the index.
81-
*
82-
* @param record_batch The Arrow RecordBatch containing buffer metadata.
83-
* @param body The raw buffer data as a byte span.
84-
* @param buffer_index The index of the buffer to retrieve. This value is incremented by the function.
85-
*
86-
* @return A `std::span<const uint8_t>` viewing the extracted buffer data.
87-
* @throws std::runtime_error if the buffer metadata indicates a buffer that exceeds the body size.
88-
*/
89-
[[nodiscard]] std::span<const uint8_t> get_buffer(
90-
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
91-
std::span<const uint8_t> body,
92-
size_t& buffer_index
93-
);
9470
}

include/sparrow_ipc/deserialize_variable_size_binary_array.hpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,7 @@ namespace sparrow_ipc
4141

4242
if (compression)
4343
{
44-
// Validity buffers can be empty
45-
if (validity_buffer_span.empty())
46-
{
47-
buffers.push_back(validity_buffer_span);
48-
}
49-
else
50-
{
51-
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
52-
}
44+
buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression));
5345
buffers.push_back(utils::get_decompressed_buffer(offset_buffer_span, compression));
5446
buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression));
5547
}

include/sparrow_ipc/flatbuffer_utils.hpp

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#include <sparrow/c_interface.hpp>
66
#include <sparrow/record_batch.hpp>
77

8+
#include "sparrow_ipc/compression.hpp"
9+
#include "sparrow_ipc/utils.hpp"
10+
811
namespace sparrow_ipc
912
{
1013
// Creates a Flatbuffers Decimal type from a format string
@@ -164,6 +167,42 @@ namespace sparrow_ipc
164167
[[nodiscard]] std::vector<org::apache::arrow::flatbuf::FieldNode>
165168
create_fieldnodes(const sparrow::record_batch& record_batch);
166169

170+
namespace details
171+
{
172+
template <typename Func>
173+
void fill_buffers_impl(
174+
const sparrow::arrow_proxy& arrow_proxy,
175+
std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_buffers,
176+
int64_t& offset,
177+
Func&& get_buffer_size
178+
)
179+
{
180+
const auto& buffers = arrow_proxy.buffers();
181+
for (const auto& buffer : buffers)
182+
{
183+
int64_t size = get_buffer_size(buffer);
184+
flatbuf_buffers.emplace_back(offset, size);
185+
offset += utils::align_to_8(size);
186+
}
187+
for (const auto& child : arrow_proxy.children())
188+
{
189+
fill_buffers_impl(child, flatbuf_buffers, offset, get_buffer_size);
190+
}
191+
}
192+
193+
template <typename Func>
194+
std::vector<org::apache::arrow::flatbuf::Buffer> get_buffers_impl(const sparrow::record_batch& record_batch, Func&& fill_buffers_func)
195+
{
196+
std::vector<org::apache::arrow::flatbuf::Buffer> buffers;
197+
int64_t offset = 0;
198+
for (const auto& column : record_batch.columns())
199+
{
200+
const auto& arrow_proxy = sparrow::detail::array_access::get_arrow_proxy(column);
201+
fill_buffers_func(arrow_proxy, buffers, offset);
202+
}
203+
return buffers;
204+
}
205+
} // namespace details
167206

168207
/**
169208
* @brief Recursively fills a vector of FlatBuffer Buffer objects with buffer information from an Arrow
@@ -205,6 +244,67 @@ namespace sparrow_ipc
205244
[[nodiscard]] std::vector<org::apache::arrow::flatbuf::Buffer>
206245
get_buffers(const sparrow::record_batch& record_batch);
207246

247+
/**
248+
* @brief Recursively populates a vector with compressed buffer metadata from an Arrow proxy.
249+
*
250+
* This function traverses the Arrow proxy and its children, compressing each buffer and recording
251+
* its metadata (offset and size) in the provided vector. The offset is updated to ensure proper
252+
* alignment for each subsequent buffer.
253+
*
254+
* @param arrow_proxy The Arrow proxy containing the buffers to be compressed.
255+
* @param flatbuf_compressed_buffers A vector to store the resulting compressed buffer metadata.
256+
* @param offset The current offset in the buffer layout, which will be updated by the function.
257+
* @param compression_type The compression algorithm to use.
258+
*/
259+
void fill_compressed_buffers(
260+
const sparrow::arrow_proxy& arrow_proxy,
261+
std::vector<org::apache::arrow::flatbuf::Buffer>& flatbuf_compressed_buffers,
262+
int64_t& offset,
263+
const CompressionType compression_type
264+
);
265+
266+
/**
267+
* @brief Retrieves metadata describing the layout of compressed buffers within a record batch.
268+
*
269+
* This function processes a record batch to determine the metadata (offset and size)
270+
* for each of its buffers, assuming they are compressed using the specified algorithm.
271+
* This metadata accounts for each compressed buffer being prefixed by its 8-byte
272+
* uncompressed size and padded to ensure 8-byte alignment.
273+
*
274+
* @param record_batch The record batch whose buffers' compressed metadata is to be retrieved.
275+
* @param compression_type The compression algorithm that would be applied (e.g., LZ4_FRAME, ZSTD).
276+
* @return A vector of FlatBuffer Buffer objects, each describing the offset and
277+
* size of a corresponding compressed buffer within a larger message body.
278+
*/
279+
[[nodiscard]] std::vector<org::apache::arrow::flatbuf::Buffer>
280+
get_compressed_buffers(const sparrow::record_batch& record_batch, const CompressionType compression_type);
281+
282+
/**
283+
* @brief Calculates the total size of the body section for an Arrow array.
284+
*
285+
* This function recursively computes the total size needed for all buffers
286+
* in an Arrow array structure, including buffers from child arrays. Each
287+
* buffer size is aligned to 8-byte boundaries as required by the Arrow format.
288+
*
289+
* @param arrow_proxy The Arrow array proxy containing buffers and child arrays
290+
* @param compression The compression type to use when serializing
291+
* @return int64_t The total aligned size in bytes of all buffers in the array hierarchy
292+
*/
293+
[[nodiscard]] int64_t calculate_body_size(const sparrow::arrow_proxy& arrow_proxy, std::optional<CompressionType> compression = std::nullopt);
294+
295+
/**
296+
* @brief Calculates the total body size of a record batch by summing the body sizes of all its columns.
297+
*
298+
* This function iterates through all columns in the given record batch and accumulates
299+
* the body size of each column's underlying Arrow array proxy. The body size represents
300+
* the total memory required for the serialized data content of the record batch.
301+
*
302+
* @param record_batch The sparrow record batch containing columns to calculate size for
303+
* @param compression The compression type to use when serializing
304+
* @return int64_t The total body size in bytes of all columns in the record batch
305+
*/
306+
[[nodiscard]] int64_t calculate_body_size(const sparrow::record_batch& record_batch, std::optional<CompressionType> compression = std::nullopt);
307+
208308
/**
209309
* @brief Creates a FlatBuffer message containing a serialized Apache Arrow RecordBatch.
210310
*
@@ -222,5 +322,5 @@ namespace sparrow_ipc
222322
* @note Variadic buffer counts are not currently implemented (set to 0)
223323
*/
224324
[[nodiscard]] flatbuffers::FlatBufferBuilder
225-
get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional<org::apache::arrow::flatbuf::CompressionType> compression = std::nullopt);
325+
get_record_batch_message_builder(const sparrow::record_batch& record_batch, std::optional<CompressionType> compression = std::nullopt);
226326
}

include/sparrow_ipc/serialize.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "Message_generated.h"
88
#include "sparrow_ipc/any_output_stream.hpp"
9+
#include "sparrow_ipc/compression.hpp"
910
#include "sparrow_ipc/config/config.hpp"
1011
#include "sparrow_ipc/magic_values.hpp"
1112
#include "sparrow_ipc/serialize_utils.hpp"
@@ -36,7 +37,7 @@ namespace sparrow_ipc
3637
*/
3738
template <std::ranges::input_range R>
3839
requires std::same_as<std::ranges::range_value_t<R>, sparrow::record_batch>
39-
void serialize_record_batches_to_ipc_stream(const R& record_batches, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression)
40+
void serialize_record_batches_to_ipc_stream(const R& record_batches, any_output_stream& stream, std::optional<CompressionType> compression)
4041
{
4142
if (record_batches.empty())
4243
{
@@ -76,7 +77,7 @@ namespace sparrow_ipc
7677
*/
7778

7879
SPARROW_IPC_API void
79-
serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional<org::apache::arrow::flatbuf::CompressionType> compression);
80+
serialize_record_batch(const sparrow::record_batch& record_batch, any_output_stream& stream, std::optional<CompressionType> compression);
8081

8182
/**
8283
* @brief Serializes a schema message for a record batch into a byte buffer.

0 commit comments

Comments
 (0)