From 24e266cf2a61ccefc49f6f4d1fef458fa386279b Mon Sep 17 00:00:00 2001 From: Fan Jiang Date: Mon, 8 Sep 2025 11:48:19 -0700 Subject: [PATCH] Fixes for Arrow STL iterator for custom types --- cpp/src/arrow/stl_iterator.h | 53 +++++++------ cpp/src/arrow/stl_iterator_test.cc | 121 +++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/stl_iterator.h b/cpp/src/arrow/stl_iterator.h index 577066cba0f..065b5dd4acd 100644 --- a/cpp/src/arrow/stl_iterator.h +++ b/cpp/src/arrow/stl_iterator.h @@ -64,12 +64,13 @@ class ArrayIterator { // Value access value_type operator*() const { assert(array_); - return array_->IsNull(index_) ? value_type{} : array_->GetView(index_); + return array_->IsNull(index_) ? value_type{} : ValueAccessor{}(*array_, index_); } value_type operator[](difference_type n) const { assert(array_); - return array_->IsNull(index_ + n) ? value_type{} : array_->GetView(index_ + n); + return array_->IsNull(index_ + n) ? value_type{} + : ValueAccessor{}(*array_, index_ + n); } int64_t index() const { return index_; } @@ -154,7 +155,7 @@ class ChunkedArrayIterator { // Value access value_type operator*() const { auto chunk_location = GetChunkLocation(index_); - ArrayIterator target_iterator{ + ArrayIterator target_iterator{ arrow::internal::checked_cast( *chunked_array_->chunk(static_cast(chunk_location.chunk_index)))}; return target_iterator[chunk_location.index_in_chunk]; @@ -247,33 +248,39 @@ class ChunkedArrayIterator { }; /// Return an iterator to the beginning of the chunked array -template ::ArrayType> -ChunkedArrayIterator Begin(const ChunkedArray& chunked_array) { - return ChunkedArrayIterator(chunked_array); +template ::ArrayType, + typename ValueAccessor = detail::DefaultValueAccessor> +ChunkedArrayIterator Begin(const ChunkedArray& chunked_array) { + return ChunkedArrayIterator(chunked_array); } /// Return an iterator to the end of the chunked array -template ::ArrayType> -ChunkedArrayIterator End(const ChunkedArray& chunked_array) { - return ChunkedArrayIterator(chunked_array, chunked_array.length()); +template ::ArrayType, + typename ValueAccessor = detail::DefaultValueAccessor> +ChunkedArrayIterator End(const ChunkedArray& chunked_array) { + return ChunkedArrayIterator(chunked_array, + chunked_array.length()); } -template +template > struct ChunkedArrayRange { const ChunkedArray* chunked_array; - ChunkedArrayIterator begin() { - return stl::ChunkedArrayIterator(*chunked_array); + ChunkedArrayIterator begin() { + return stl::ChunkedArrayIterator(*chunked_array); } - ChunkedArrayIterator end() { - return stl::ChunkedArrayIterator(*chunked_array, chunked_array->length()); + ChunkedArrayIterator end() { + return stl::ChunkedArrayIterator(*chunked_array, + chunked_array->length()); } }; /// Return an iterable range over the chunked array -template ::ArrayType> -ChunkedArrayRange Iterate(const ChunkedArray& chunked_array) { - return stl::ChunkedArrayRange{&chunked_array}; +template ::ArrayType, + typename ValueAccessor = detail::DefaultValueAccessor> +ChunkedArrayRange Iterate(const ChunkedArray& chunked_array) { + return stl::ChunkedArrayRange{&chunked_array}; } } // namespace stl @@ -281,9 +288,9 @@ ChunkedArrayRange Iterate(const ChunkedArray& chunked_array) { namespace std { -template -struct iterator_traits<::arrow::stl::ArrayIterator> { - using IteratorType = ::arrow::stl::ArrayIterator; +template +struct iterator_traits<::arrow::stl::ArrayIterator> { + using IteratorType = ::arrow::stl::ArrayIterator; using difference_type = typename IteratorType::difference_type; using value_type = typename IteratorType::value_type; using pointer = typename IteratorType::pointer; @@ -291,9 +298,9 @@ struct iterator_traits<::arrow::stl::ArrayIterator> { using iterator_category = typename IteratorType::iterator_category; }; -template -struct iterator_traits<::arrow::stl::ChunkedArrayIterator> { - using IteratorType = ::arrow::stl::ChunkedArrayIterator; +template +struct iterator_traits<::arrow::stl::ChunkedArrayIterator> { + using IteratorType = ::arrow::stl::ChunkedArrayIterator; using difference_type = typename IteratorType::difference_type; using value_type = typename IteratorType::value_type; using pointer = typename IteratorType::pointer; diff --git a/cpp/src/arrow/stl_iterator_test.cc b/cpp/src/arrow/stl_iterator_test.cc index 3fe57ebc0d4..c2c4ac0202a 100644 --- a/cpp/src/arrow/stl_iterator_test.cc +++ b/cpp/src/arrow/stl_iterator_test.cc @@ -248,6 +248,68 @@ TEST(ArrayIterator, StdMerge) { ASSERT_EQ(values, expected); } +// Custom ValueAccessor for DictionaryArray that decodes values +struct TestDictionaryValueAccessor { + using ValueType = std::string_view; + + inline ValueType operator()(const DictionaryArray& array, int64_t index) { + // Get the dictionary index for this position + int64_t dict_index = array.GetValueIndex(index); + + // Get the dictionary and cast it to StringArray + auto dict = checked_pointer_cast(array.dictionary()); + + // Return the decoded string value + return dict->GetView(dict_index); + } +}; + +TEST(ArrayIterator, CustomValueAccessorDictionary) { + // Create a dictionary array with string values + auto dict = ArrayFromJSON(utf8(), R"(["apple", "banana", "cherry", "date"])"); + auto indices = ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 1, 0, null, 3]"); + + auto dict_type = dictionary(int32(), utf8()); + auto dict_array = std::make_shared(dict_type, indices, dict); + + // Use custom accessor to iterate over decoded values + ArrayIterator it(*dict_array); + + // Test basic access + ASSERT_EQ(*it, "apple"); + ASSERT_EQ(it[1], "banana"); + ASSERT_EQ(it[2], "cherry"); + ASSERT_EQ(it[3], "date"); + ASSERT_EQ(it[4], "cherry"); + ASSERT_EQ(it[5], "banana"); + ASSERT_EQ(it[6], "apple"); + ASSERT_EQ(it[7], nullopt); // null index + ASSERT_EQ(it[8], "date"); + + // Test iteration + std::vector> values; + for (auto end = it + 9; it != end; ++it) { + values.push_back(*it); + } + + std::vector> expected{ + "apple", "banana", "cherry", "date", "cherry", "banana", "apple", nullopt, "date"}; + ASSERT_EQ(values, expected); + + // Test with algorithms - find a specific value + ArrayIterator begin(*dict_array); + ArrayIterator end(*dict_array, + dict_array->length()); + + auto found = std::find(begin, end, "cherry"); + ASSERT_NE(found, end); + ASSERT_EQ(found.index(), 2); // First occurrence of "cherry" + + // Count occurrences of "banana" + auto count = std::count(begin, end, "banana"); + ASSERT_EQ(count, 2); +} + TEST(ChunkedArrayIterator, Basics) { auto result = ChunkedArrayFromJSON(int32(), {R"([4, 5, null])", R"([6])"}); auto it = Begin(*result); @@ -545,5 +607,64 @@ TEST(ChunkedArrayIterator, ForEachIterator) { ASSERT_EQ(values, expected); } +TEST(ChunkedArrayIterator, CustomValueAccessorDictionary) { + // Create multiple dictionary arrays with the same dictionary + auto dict = ArrayFromJSON(utf8(), R"(["red", "green", "blue", "yellow"])"); + + auto indices1 = ArrayFromJSON(int32(), "[0, 1, 2]"); + auto indices2 = ArrayFromJSON(int32(), "[3, 2, null]"); + auto indices3 = ArrayFromJSON(int32(), "[1, 0, 3, 2]"); + + auto dict_type = dictionary(int32(), utf8()); + auto dict_array1 = std::make_shared(dict_type, indices1, dict); + auto dict_array2 = std::make_shared(dict_type, indices2, dict); + auto dict_array3 = std::make_shared(dict_type, indices3, dict); + + // Create chunked array from dictionary arrays + auto chunked_array = std::make_shared( + std::vector>{dict_array1, dict_array2, dict_array3}, + dict_type); + + // Use custom accessor to iterate over decoded values across chunks + auto it = + Begin(*chunked_array); + auto end = + End(*chunked_array); + + // Test sequential access across chunks + ASSERT_EQ(*it, "red"); // chunk 0, index 0 + ASSERT_EQ(*(it + 1), "green"); // chunk 0, index 1 + ASSERT_EQ(*(it + 2), "blue"); // chunk 0, index 2 + ASSERT_EQ(*(it + 3), "yellow"); // chunk 1, index 0 + ASSERT_EQ(*(it + 4), "blue"); // chunk 1, index 1 + ASSERT_EQ(*(it + 5), nullopt); // chunk 1, index 2 (null) + ASSERT_EQ(*(it + 6), "green"); // chunk 2, index 0 + ASSERT_EQ(*(it + 7), "red"); // chunk 2, index 1 + ASSERT_EQ(*(it + 8), "yellow"); // chunk 2, index 2 + ASSERT_EQ(*(it + 9), "blue"); // chunk 2, index 3 + + // Collect all values + std::vector> values; + + for (auto elem : Iterate( + *chunked_array)) { + values.push_back(elem); + } + + std::vector> expected{"red", "green", "blue", "yellow", + "blue", nullopt, "green", "red", + "yellow", "blue"}; + ASSERT_EQ(values, expected); + + // Test with algorithms - count occurrences of "blue" + auto count = std::count(it, end, "blue"); + ASSERT_EQ(count, 3); + + // Find first occurrence of "yellow" + auto found = std::find(it, end, "yellow"); + ASSERT_NE(found, end); + ASSERT_EQ(found.index(), 3); +} + } // namespace stl } // namespace arrow