Fix error with multiple nested partition columns on Iceberg (#24629)
jinyangli34 authored Jan 27, 2025
1 parent e2ec103 commit 4aaea89
Showing 2 changed files with 76 additions and 13 deletions.
@@ -14,6 +14,7 @@
 package io.trino.plugin.iceberg;

 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.hash.Hasher;
 import com.google.common.hash.Hashing;
 import io.trino.spi.connector.ConnectorPartitioningHandle;
@@ -27,6 +28,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;

 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableSet.toImmutableSet;
@@ -60,30 +62,64 @@ public static IcebergPartitioningHandle create(PartitionSpec spec, TypeManager t
         return new IcebergPartitioningHandle(false, partitionFields);
     }

+    /**
+     * Constructs a map of field IDs to data paths.
+     * The data path for a root field is the ordinal position of the partition field under this root field, as defined by {@link IcebergMetadata#getWriteLayout}.
+     * The data path for a non-root nested field is its ordinal position among its parent's nested fields.
+     * e.g. for a schema {f1: {f3, f4}, f2, f5},
+     * when partitioned by f1.f3 and f2, the data paths are {3 : [1, 0], 2 : [0]};
+     * when partitioned by f1.f4 and f5, the data paths are {4 : [0, 1], 5 : [1]}
+     */
     private static Map<Integer, List<Integer>> buildDataPaths(PartitionSpec spec)
     {
         Set<Integer> partitionFieldIds = spec.fields().stream().map(PartitionField::sourceId).collect(toImmutableSet());

-        int channel = 0;
+        /*
+         * In this loop, the field ID acts as a placeholder in the first position of each data path.
+         * Later, these placeholders are replaced with the actual channel IDs, assigned in order of the partitioned sub-field IDs.
+         */
         Map<Integer, List<Integer>> fieldInfo = new HashMap<>();
         for (Types.NestedField field : spec.schema().asStruct().fields()) {
             // Partition fields can only be nested in a struct
             if (field.type() instanceof Types.StructType nestedStruct) {
-                if (buildDataPaths(partitionFieldIds, nestedStruct, new ArrayDeque<>(List.of(channel)), fieldInfo)) {
-                    channel++;
-                }
+                buildDataPaths(partitionFieldIds, nestedStruct, new ArrayDeque<>(ImmutableList.of(field.fieldId())), fieldInfo);
             }
             else if (field.type().isPrimitiveType() && partitionFieldIds.contains(field.fieldId())) {
-                fieldInfo.put(field.fieldId(), ImmutableList.of(channel));
-                channel++;
+                fieldInfo.put(field.fieldId(), ImmutableList.of(field.fieldId()));
             }
         }
-        return fieldInfo;
+
+        /*
+         * Replace the root field ID with the actual channel ID.
+         * Transformation: {fieldId : rootFieldId.structOrdinalX.structOrdinalY} -> {fieldId : channel.structOrdinalX.structOrdinalY}.
+         * Each root field's channel ID is assigned sequentially, based on the key fieldId.
+         */
+        List<Integer> sortedFieldIds = fieldInfo.keySet().stream()
+                .sorted()
+                .collect(toImmutableList());
+
+        ImmutableMap.Builder<Integer, List<Integer>> builder = ImmutableMap
+                .builderWithExpectedSize(sortedFieldIds.size());
+
+        Map<Integer, Integer> fieldChannels = new HashMap<>();
+        AtomicInteger channel = new AtomicInteger();
+        for (int sortedFieldId : sortedFieldIds) {
+            List<Integer> dataPath = fieldInfo.get(sortedFieldId);
+            int fieldChannel = fieldChannels.computeIfAbsent(dataPath.getFirst(), _ -> channel.getAndIncrement());
+            List<Integer> channelDataPath = ImmutableList.<Integer>builder()
+                    .add(fieldChannel)
+                    .addAll(dataPath.stream()
+                            .skip(1)
+                            .iterator())
+                    .build();
+            builder.put(sortedFieldId, channelDataPath);
+        }
+
+        return builder.buildOrThrow();
     }

-    private static boolean buildDataPaths(Set<Integer> partitionFieldIds, Types.StructType struct, ArrayDeque<Integer> currentPaths, Map<Integer, List<Integer>> dataPaths)
+    private static void buildDataPaths(Set<Integer> partitionFieldIds, Types.StructType struct, ArrayDeque<Integer> currentPaths, Map<Integer, List<Integer>> dataPaths)
     {
-        boolean hasPartitionFields = false;
         List<Types.NestedField> fields = struct.fields();
         for (int fieldOrdinal = 0; fieldOrdinal < fields.size(); fieldOrdinal++) {
             Types.NestedField field = fields.get(fieldOrdinal);
@@ -92,16 +128,14 @@ private static boolean buildDataPaths(Set<Integer> partitionFieldIds, Types.Stru
             currentPaths.addLast(fieldOrdinal);
             org.apache.iceberg.types.Type type = field.type();
             if (type instanceof Types.StructType nestedStruct) {
-                hasPartitionFields = buildDataPaths(partitionFieldIds, nestedStruct, currentPaths, dataPaths) || hasPartitionFields;
+                buildDataPaths(partitionFieldIds, nestedStruct, currentPaths, dataPaths);
             }
             // Map and List types are not supported in partitioning
-            if (type.isPrimitiveType() && partitionFieldIds.contains(fieldId)) {
+            else if (type.isPrimitiveType() && partitionFieldIds.contains(fieldId)) {
                 dataPaths.put(fieldId, ImmutableList.copyOf(currentPaths));
-                hasPartitionFields = true;
             }
             currentPaths.removeLast();
         }
-        return hasPartitionFields;
     }

     public long getCacheKeyHint
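Why the two passes: the removed single-pass version assigned channels while walking the schema, so channel order followed schema position; the rewritten version first records each path with a root-field-ID placeholder, then assigns channels in sorted-field-ID order. The following minimal, dependency-free sketch reproduces that two-pass mapping on the Javadoc's example schema. The Field record, class name, and main method are illustrative stand-ins for this sketch, not the connector's actual types.

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

public class DataPathsSketch
{
    // Illustrative stand-in for a schema field: an ID, plus child fields when it is a struct
    record Field(int id, List<Field> children)
    {
        static Field primitive(int id)
        {
            return new Field(id, List.of());
        }
    }

    public static void main(String[] args)
    {
        // The Javadoc's example schema {f1: {f3, f4}, f2, f5}, with IDs matching the names
        List<Field> schema = List.of(
                new Field(1, List.of(Field.primitive(3), Field.primitive(4))),
                Field.primitive(2),
                Field.primitive(5));

        // Partitioned by f1.f3 and f2 -> prints {2=[0], 3=[1, 0]}
        System.out.println(buildDataPaths(Set.of(3, 2), schema));
        // Partitioned by f1.f4 and f5 -> prints {4=[0, 1], 5=[1]}
        System.out.println(buildDataPaths(Set.of(4, 5), schema));
    }

    static Map<Integer, List<Integer>> buildDataPaths(Set<Integer> partitionFieldIds, List<Field> schema)
    {
        // Pass 1: record each partitioned field's path, with its root field's ID
        // as a placeholder in the first position
        Map<Integer, List<Integer>> fieldInfo = new HashMap<>();
        for (Field field : schema) {
            if (!field.children().isEmpty()) {
                collect(partitionFieldIds, field, new ArrayDeque<>(List.of(field.id())), fieldInfo);
            }
            else if (partitionFieldIds.contains(field.id())) {
                fieldInfo.put(field.id(), List.of(field.id()));
            }
        }

        // Pass 2: walk the entries in sorted field-ID order and replace each placeholder
        // root field ID with a channel assigned sequentially on first sight
        Map<Integer, Integer> rootChannels = new HashMap<>();
        Map<Integer, List<Integer>> dataPaths = new TreeMap<>();
        for (int fieldId : new TreeMap<>(fieldInfo).keySet()) {
            List<Integer> path = fieldInfo.get(fieldId);
            int channel = rootChannels.computeIfAbsent(path.get(0), id -> rootChannels.size());
            List<Integer> rewritten = new ArrayList<>(path);
            rewritten.set(0, channel);
            dataPaths.put(fieldId, rewritten);
        }
        return dataPaths;
    }

    static void collect(Set<Integer> partitionFieldIds, Field struct, ArrayDeque<Integer> currentPath, Map<Integer, List<Integer>> fieldInfo)
    {
        List<Field> children = struct.children();
        for (int ordinal = 0; ordinal < children.size(); ordinal++) {
            Field child = children.get(ordinal);
            currentPath.addLast(ordinal);
            if (!child.children().isEmpty()) {
                collect(partitionFieldIds, child, currentPath, fieldInfo);
            }
            else if (partitionFieldIds.contains(child.id())) {
                fieldInfo.put(child.id(), List.copyOf(currentPath));
            }
            currentPath.removeLast();
        }
    }
}

The printed maps match the data paths given in the Javadoc, with only the map's key order differing from the prose.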
@@ -639,6 +639,35 @@ void testEntriesTable()
         }
     }

+    @Test
+    public void testPartitionColumns()
+    {
+        try (TestTable testTable = new TestTable(getQueryRunner()::execute, "test_partition_columns", """
+                WITH (partitioning = ARRAY[
+                    '"r1.f1"',
+                    'bucket(b1, 4)'
+                ]) AS
+                SELECT
+                    CAST(ROW(1, 2) AS ROW(f1 INTEGER, f2 INTEGER)) as r1
+                    , CAST('b' AS VARCHAR) as b1""")) {
+            assertThat(query("SELECT partition FROM \"" + testTable.getName() + "$partitions\""))
+                    .matches("SELECT CAST(ROW(1, 3) AS ROW(\"r1.f1\" INTEGER, b1_bucket INTEGER))");
+        }
+
+        try (TestTable testTable = new TestTable(getQueryRunner()::execute, "test_partition_columns", """
+                WITH (partitioning = ARRAY[
+                    '"r1.f2"',
+                    'bucket(b1, 4)',
+                    '"r1.f1"'
+                ]) AS
+                SELECT
+                    CAST(ROW('f1', 'f2') AS ROW(f1 VARCHAR, f2 VARCHAR)) as r1
+                    , CAST('b' AS VARCHAR) as b1""")) {
+            assertThat(query("SELECT partition FROM \"" + testTable.getName() + "$partitions\""))
+                    .matches("SELECT CAST(ROW('f2', 3, 'f1') AS ROW(\"r1.f2\" VARCHAR, b1_bucket INTEGER, \"r1.f1\" VARCHAR))");
+        }
+    }
+
     @Test
     void testEntriesPartitionTable()
     {
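The second test case is the one that exercises the fix: a nested column with a larger source field ID is partitioned ahead of a top-level column. Feeding it to the sketch above, and assuming Iceberg's usual top-level-first ID assignment for this table (r1 = 1, b1 = 2, r1.f1 = 3, r1.f2 = 4 — inferred, not shown in the commit):

// Second test case, using the sketch's Field record.
// Hypothetical field IDs (inferred): r1 = 1 (struct), b1 = 2, r1.f1 = 3, r1.f2 = 4
List<Field> schema = List.of(
        new Field(1, List.of(Field.primitive(3), Field.primitive(4))),
        Field.primitive(2));

// Partition sources: r1.f2 (4), bucket(b1, 4) (2), r1.f1 (3)
// prints {2=[0], 3=[1, 0], 4=[1, 1]}: b1 has the smallest source field ID, so it
// takes channel 0, and both nested fields share the r1 struct's channel 1
System.out.println(buildDataPaths(Set.of(4, 2, 3), schema));

The removed single-pass code would instead have handed out channels in schema-walk order (the r1 struct first, then b1), so its channels could disagree with the sorted-by-field-ID order once multiple nested partition columns were involved — the failure mode the commit title describes.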
