Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial result to AggregateCursor continuation #3254

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
import com.apple.foundationdb.async.AsyncUtil;
import com.apple.foundationdb.record.RecordCursor;
import com.apple.foundationdb.record.RecordCursorContinuation;
import com.apple.foundationdb.record.RecordCursorProto;
import com.apple.foundationdb.record.RecordCursorResult;
import com.apple.foundationdb.record.RecordCursorStartContinuation;
import com.apple.foundationdb.record.RecordCursorVisitor;
import com.apple.foundationdb.record.query.plan.plans.QueryResult;
import com.google.common.base.Verify;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;

import javax.annotation.Nonnull;
Expand All @@ -54,31 +55,43 @@ public class AggregateCursor<M extends Message> implements RecordCursor<QueryRes
// Previous record processed by this cursor
@Nullable
private RecordCursorResult<QueryResult> previousResult;
@Nullable
// when previousResult = row x, lastResult = row (x-1); when previousResult = null, lastResult = null
private RecordCursorResult<QueryResult> lastResult;
// Previous non-empty record processed by this cursor
@Nullable
private RecordCursorResult<QueryResult> previousValidResult;
@Nullable
private RecordCursorProto.PartialAggregationResult partialAggregationResult;
@Nullable
byte[] continuation;

public AggregateCursor(@Nonnull RecordCursor<QueryResult> inner,
@Nonnull final StreamGrouping<M> streamGrouping,
boolean isCreateDefaultOnEmpty) {
boolean isCreateDefaultOnEmpty,
@Nullable byte[] continuation) {
this.inner = inner;
this.streamGrouping = streamGrouping;
this.isCreateDefaultOnEmpty = isCreateDefaultOnEmpty;
this.continuation = continuation;
}

@Nonnull
@Override
public CompletableFuture<RecordCursorResult<QueryResult>> onNext() {
if (previousResult != null && !previousResult.hasNext()) {
// we are done
return CompletableFuture.completedFuture(RecordCursorResult.exhausted());
return CompletableFuture.completedFuture(RecordCursorResult.withoutNextValue(new AggregateCursorContinuation(previousResult.getContinuation()),
previousResult.getNoNextReason()));
}

return AsyncUtil.whileTrue(() -> inner.onNext().thenApply(innerResult -> {
lastResult = previousResult;
previousResult = innerResult;
if (!innerResult.hasNext()) {
if (!isNoRecords() || (isCreateDefaultOnEmpty && streamGrouping.isResultOnEmpty())) {
streamGrouping.finalizeGroup();
// the method streamGrouping.finalizeGroup() computes previousCompleteResult and resets the accumulator
partialAggregationResult = streamGrouping.finalizeGroup();
}
return false;
} else {
Expand All @@ -91,45 +104,60 @@ public CompletableFuture<RecordCursorResult<QueryResult>> onNext() {
return (!groupBreak);
}
}), getExecutor()).thenApply(vignore -> {
if (isNoRecords()) {
// Edge case where there are no records at all
if (isCreateDefaultOnEmpty && streamGrouping.isResultOnEmpty()) {
return RecordCursorResult.withNextValue(QueryResult.ofComputed(streamGrouping.getCompletedGroupResult()), RecordCursorStartContinuation.START);
// either innerResult.hasNext() = false; or groupBreak = true
if (Verify.verifyNotNull(previousResult).hasNext()) {
// in this case groupBreak = true, return aggregated result and continuation, partialAggregationResult = null
// previousValidResult = null happens when 1st row of current scan != last row of last scan, results in groupBreak = true and previousValidResult = null
RecordCursorContinuation c = previousValidResult == null ? new AggregateCursorContinuation(continuation, false) : new AggregateCursorContinuation(previousValidResult.getContinuation());

/*
* Update the previousValidResult to the next continuation even though it hasn't been returned. This is to return the correct continuation when there are single-element groups.
* Below is an example that shows how continuation(previousValidResult) moves:
* Initial: previousResult = null, previousValidResult = null
row0 groupKey0 groupBreak = False previousValidResult = row0 previousResult = row0
row1 groupKey0 groupBreak = False previousValidResult = row1 previousResult = row1
row2 groupKey1 groupBreak = True previousValidResult = row1 previousResult = row2
* returns result (groupKey0, continuation = row1), and set previousValidResult = row2
*
* Now there are 2 scenarios, 1) the current iteration continues; 2) the current iteration stops
* In scenario 1, the iteration continues, it gets to row3:
row3 groupKey2 groupBreak = True previousValidResult = row2 previousResult = row3
* returns result (groupKey1, continuation = row2), and set previousValidResult = row3
*
* In scenario 2, a new iteration starts from row2 (because the last returned continuation = row1), and set initial previousResult = null, previousValidResult = null:
row2 groupKey1 groupBreak = False previousValidResult = row2 previousResult = row2
* (Note that because a new iteration starts, groupBreak = False for row2.)
row3 groupKey2 groupBreak = True previousValidResult = row2 previousResult = row3
* returns result (groupKey1, continuation = row2), and set previousValidResult = row3
*
* Both scenarios returns the correct result, and continuation are both set to row3 in the end, row2 is scanned twice if a new iteration starts.
*/
previousValidResult = previousResult;
return RecordCursorResult.withNextValue(QueryResult.ofComputed(streamGrouping.getCompletedGroupResult()), c);
} else {
// innerResult.hasNext() = false, might stop in the middle of a group
if (Verify.verifyNotNull(previousResult).getNoNextReason() == NoNextReason.SOURCE_EXHAUSTED) {
// exhausted
if (previousValidResult == null && partialAggregationResult == null) {
return RecordCursorResult.exhausted();
} else {
RecordCursorContinuation c = previousValidResult == null ? new AggregateCursorContinuation(continuation, false) : new AggregateCursorContinuation(previousValidResult.getContinuation());
previousValidResult = previousResult;
return RecordCursorResult.withNextValue(QueryResult.ofComputed(streamGrouping.getCompletedGroupResult()), c);
}
} else {
return RecordCursorResult.exhausted();
// stopped in the middle of a group
RecordCursorContinuation currentContinuation = new AggregateCursorContinuation(lastResult.getContinuation(), partialAggregationResult);
previousValidResult = previousResult;
return RecordCursorResult.withoutNextValue(currentContinuation, Verify.verifyNotNull(previousResult).getNoNextReason());
}
}
// Use the last valid result for the continuation as we need non-terminal one here.
RecordCursorContinuation continuation = Verify.verifyNotNull(previousValidResult).getContinuation();
/*
* Update the previousValidResult to the next continuation even though it hasn't been returned. This is to return the correct continuation when there are single-element groups.
* Below is an example that shows how continuation(previousValidResult) moves:
* Initial: previousResult = null, previousValidResult = null
row0 groupKey0 groupBreak = False previousValidResult = row0 previousResult = row0
row1 groupKey0 groupBreak = False previousValidResult = row1 previousResult = row1
row2 groupKey1 groupBreak = True previousValidResult = row1 previousResult = row2
* returns result (groupKey0, continuation = row1), and set previousValidResult = row2
*
* Now there are 2 scenarios, 1) the current iteration continues; 2) the current iteration stops
* In scenario 1, the iteration continues, it gets to row3:
row3 groupKey2 groupBreak = True previousValidResult = row2 previousResult = row3
* returns result (groupKey1, continuation = row2), and set previousValidResult = row3
*
* In scenario 2, a new iteration starts from row2 (because the last returned continuation = row1), and set initial previousResult = null, previousValidResult = null:
row2 groupKey1 groupBreak = False previousValidResult = row2 previousResult = row2
* (Note that because a new iteration starts, groupBreak = False for row2.)
row3 groupKey2 groupBreak = True previousValidResult = row2 previousResult = row3
* returns result (groupKey1, continuation = row2), and set previousValidResult = row3
*
* Both scenarios returns the correct result, and continuation are both set to row3 in the end, row2 is scanned twice if a new iteration starts.
*/
previousValidResult = previousResult;
return RecordCursorResult.withNextValue(QueryResult.ofComputed(streamGrouping.getCompletedGroupResult()), continuation);
});
}



private boolean isNoRecords() {
return ((previousValidResult == null) && (!Verify.verifyNotNull(previousResult).hasNext()));
return ((previousValidResult == null) && (!Verify.verifyNotNull(previousResult).hasNext()) && (streamGrouping.getPartialAggregationResult() == null));
}

@Override
Expand All @@ -155,4 +183,104 @@ public boolean accept(@Nonnull RecordCursorVisitor visitor) {
}
return visitor.visitLeave(this);
}

public static class AggregateCursorContinuation implements RecordCursorContinuation {
@Nullable
private final ByteString innerContinuation;

@Nullable
private final RecordCursorProto.PartialAggregationResult partialAggregationResult;

@Nullable
private RecordCursorProto.AggregateCursorContinuation cachedProto;

private final boolean isEnd;

public AggregateCursorContinuation(@Nonnull RecordCursorContinuation other) {
this(other.toBytes(), other.isEnd());
}

public AggregateCursorContinuation(@Nonnull RecordCursorContinuation other, @Nullable RecordCursorProto.PartialAggregationResult partialAggregationResult) {
this.isEnd = other.isEnd();
this.innerContinuation = other.toByteString();
this.partialAggregationResult = partialAggregationResult;
}

public AggregateCursorContinuation(@Nullable byte[] innerContinuation, boolean isEnd, @Nullable RecordCursorProto.PartialAggregationResult partialAggregationResult) {
this.isEnd = isEnd;
this.innerContinuation = innerContinuation == null ? null : ByteString.copyFrom(innerContinuation);
this.partialAggregationResult = partialAggregationResult;
}

public AggregateCursorContinuation(@Nullable byte[] innerContinuation, boolean isEnd) {
this(innerContinuation, isEnd, null);
}

@Nonnull
@Override
public ByteString toByteString() {
if (isEnd()) {
return ByteString.EMPTY;
} else {
return toProto().toByteString();
}
}

@Nullable
@Override
public byte[] toBytes() {
if (isEnd()) {
return null;
}
return toProto().toByteArray();
}

@Override
public boolean isEnd() {
return isEnd;
}

@Nullable
public byte[] getInnerContinuation() {
return innerContinuation == null ? null : innerContinuation.toByteArray();
}

@Nullable
public RecordCursorProto.PartialAggregationResult getPartialAggregationResult() {
return partialAggregationResult;
}

@Nonnull
private RecordCursorProto.AggregateCursorContinuation toProto() {
if (cachedProto == null) {
RecordCursorProto.AggregateCursorContinuation.Builder cachedProtoBuilder = RecordCursorProto.AggregateCursorContinuation.newBuilder();
if (partialAggregationResult != null) {
cachedProtoBuilder.setPartialAggregationResults(partialAggregationResult);
}
if (innerContinuation != null) {
cachedProtoBuilder.setContinuation(innerContinuation);
}
cachedProto = cachedProtoBuilder.build();
}
return cachedProto;
}

public static AggregateCursorContinuation fromRawBytes(@Nullable byte[] rawBytes) {
if (rawBytes == null) {
return new AggregateCursorContinuation(null, true);
}
try {
RecordCursorProto.AggregateCursorContinuation continuationProto = RecordCursorProto.AggregateCursorContinuation.parseFrom(rawBytes);
if (!continuationProto.hasContinuation()) {
return new AggregateCursorContinuation(null, true);
} else if (continuationProto.hasPartialAggregationResults()) {
return new AggregateCursorContinuation(continuationProto.getContinuation().toByteArray(), false, continuationProto.getPartialAggregationResults());
} else {
return new AggregateCursorContinuation(continuationProto.getContinuation().toByteArray(), false);
}
} catch (final Exception ex) {
throw new RuntimeException(ex);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@

import com.apple.foundationdb.record.Bindings;
import com.apple.foundationdb.record.EvaluationContext;
import com.apple.foundationdb.record.RecordCursorProto;
import com.apple.foundationdb.record.RecordCursorResult;
import com.apple.foundationdb.record.provider.foundationdb.FDBRecordStoreBase;
import com.apple.foundationdb.record.query.plan.cascades.CorrelationIdentifier;
import com.apple.foundationdb.record.query.plan.cascades.values.Accumulator;
import com.apple.foundationdb.record.query.plan.cascades.values.AggregateValue;
import com.apple.foundationdb.record.query.plan.cascades.values.Value;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;

import javax.annotation.Nonnull;
Expand Down Expand Up @@ -104,10 +107,19 @@ public StreamGrouping(@Nullable final Value groupingKeyValue,
@Nonnull final CorrelationIdentifier aggregateAlias,
@Nonnull final FDBRecordStoreBase<M> store,
@Nonnull final EvaluationContext context,
@Nonnull final CorrelationIdentifier alias) {
@Nonnull final CorrelationIdentifier alias,
@Nullable final RecordCursorProto.PartialAggregationResult partialAggregationResult) {
this.groupingKeyValue = groupingKeyValue;
this.aggregateValue = aggregateValue;
this.accumulator = aggregateValue.createAccumulator(context.getTypeRepository());
if (partialAggregationResult != null) {
this.accumulator.setInitialState(partialAggregationResult.getAccumulatorStatesList());
try {
this.currentGroup = DynamicMessage.parseFrom(context.getTypeRepository().newMessageBuilder(groupingKeyValue.getResultType()).getDescriptorForType(), partialAggregationResult.getGroupKey().toByteArray());
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
}
this.store = store;
this.context = context;
this.alias = alias;
Expand Down Expand Up @@ -167,20 +179,22 @@ private boolean isGroupBreak(final Object currentGroup, final Object nextGroup)
}
}

public void finalizeGroup() {
finalizeGroup(null);
public RecordCursorProto.PartialAggregationResult finalizeGroup() {
return finalizeGroup(null);
}

private void finalizeGroup(Object nextGroup) {
private RecordCursorProto.PartialAggregationResult finalizeGroup(Object nextGroup) {
final EvaluationContext nestedContext = context.childBuilder()
.setBinding(groupingKeyAlias, currentGroup)
.setBinding(aggregateAlias, accumulator.finish())
.build(context.getTypeRepository());
previousCompleteResult = completeResultValue.eval(store, nestedContext);

RecordCursorProto.PartialAggregationResult result = currentGroup == null ? null : getPartialAggregationResult((Message) currentGroup);
currentGroup = nextGroup;
// "Reset" the accumulator by creating a fresh one.
accumulator = aggregateValue.createAccumulator(context.getTypeRepository());
return result;
}

private void accumulate(@Nullable Object currentObject) {
Expand All @@ -197,4 +211,13 @@ private Object evalGroupingKey(@Nullable final Object currentObject) {
public boolean isResultOnEmpty() {
return groupingKeyValue == null;
}

@Nullable
public RecordCursorProto.PartialAggregationResult getPartialAggregationResult(@Nonnull Message groupingKey) {
return accumulator.getPartialAggregationResult(groupingKey);
}

public RecordCursorProto.PartialAggregationResult getPartialAggregationResult() {
return accumulator.getPartialAggregationResult((Message)currentGroup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ public T setHigh(@Nonnull byte[] highBytes, @Nonnull EndpointType highEndpoint)
protected int calculatePrefixLength() {
int prefixLength = subspace.pack().length;
while ((prefixLength < lowBytes.length) &&
(prefixLength < highBytes.length) &&
(lowBytes[prefixLength] == highBytes[prefixLength])) {
(prefixLength < highBytes.length) &&
(lowBytes[prefixLength] == highBytes[prefixLength])) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weird auto formating, will remove it in the next commit.

prefixLength++;
}
return prefixLength;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

package com.apple.foundationdb.record.query.plan.cascades.values;

import com.apple.foundationdb.record.RecordCursorProto;
import com.google.protobuf.Message;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.List;

/**
* An aggregate accumulator.
Expand All @@ -29,4 +34,9 @@ public interface Accumulator {
void accumulate(@Nullable Object currentObject);

@Nullable Object finish();

@Nullable
RecordCursorProto.PartialAggregationResult getPartialAggregationResult(Message groupingKey);

void setInitialState(@Nonnull List<RecordCursorProto.AccumulatorState> accumulatorStates);
}
Loading
Loading