Skip to content

Commit 004feb7

Browse files
sandeshkr419Peter Alfonsi
authored andcommitted
multi terms aggregation changes (opensearch-project#19284)
Signed-off-by: Sandesh Kumar <[email protected]>
1 parent 0971319 commit 004feb7

File tree

12 files changed

+616
-44
lines changed

12 files changed

+616
-44
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
3636
- Implement GRPC Boolean query and inject registry for all internal query converters ([#19391](https://github.com/opensearch-project/OpenSearch/pull/19391))
3737
- Implement GRPC Script query ([#19455](https://github.com/opensearch-project/OpenSearch/pull/19455))
3838
- [Search Stats] Add search & star-tree search query failure count metrics ([#19210](https://github.com/opensearch-project/OpenSearch/issues/19210))
39+
- [Star-tree] Support for multi-terms aggregation ([#18398](https://github.com/opensearch-project/OpenSearch/issues/18398))
3940

4041
### Changed
4142
- Refactor `if-else` chains to use `Java 17 pattern matching switch expressions`(([#18965](https://github.com/opensearch-project/OpenSearch/pull/18965))

server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/iterator/SortedNumericStarTreeValuesIterator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public long nextValue() throws IOException {
3535
return ((SortedNumericDocValues) docIdSetIterator).nextValue();
3636
}
3737

38-
public int entryValueCount() throws IOException {
38+
public int entryValueCount() {
3939
return ((SortedNumericDocValues) docIdSetIterator).docValueCount();
4040
}
4141

server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/iterator/SortedSetStarTreeValuesIterator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public long nextOrd() throws IOException {
4444
return ((SortedSetDocValues) docIdSetIterator).nextOrd();
4545
}
4646

47-
public int docValueCount() {
47+
public int entryValueCount() {
4848
return ((SortedSetDocValues) docIdSetIterator).docValueCount();
4949
}
5050

server/src/main/java/org/opensearch/index/compositeindex/datacube/startree/utils/iterator/StarTreeValuesIterator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ public int advance(int target) throws IOException {
4242
return docIdSetIterator.advance(target);
4343
}
4444

45+
public abstract int entryValueCount();
46+
4547
public long cost() {
4648
return docIdSetIterator.cost();
4749
}

server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ public void collectStarTreeEntry(int starTreeEntry, long owningBucketOrd) throws
366366
if (valuesIterator.advanceExact(starTreeEntry) == false) {
367367
return;
368368
}
369-
for (int i = 0, count = valuesIterator.docValueCount(); i < count; i++) {
369+
for (int i = 0, count = valuesIterator.entryValueCount(); i < count; i++) {
370370
long dimensionValue = valuesIterator.value();
371371
long ord = globalOperator.applyAsLong(dimensionValue);
372372
if (docCountsIterator.advanceExact(starTreeEntry)) {

server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregationFactory.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public class MultiTermsAggregationFactory extends AggregatorFactory {
4848
private final Aggregator.SubAggCollectionMode collectMode;
4949
private final TermsAggregator.BucketCountThresholds bucketCountThresholds;
5050
private final boolean showTermDocCountError;
51+
private final List<String> requestFields;
5152

5253
public static void registerAggregators(ValuesSourceRegistry.Builder builder) {
5354
builder.register(REGISTRY_KEY, List.of(CoreValuesSourceType.BYTES, CoreValuesSourceType.IP), config -> {
@@ -116,6 +117,7 @@ public MultiTermsAggregationFactory(
116117
)
117118
)
118119
.collect(Collectors.toList());
120+
this.requestFields = multiTermConfigs.stream().map(MultiTermsValuesSourceConfig::getFieldName).toList();
119121
this.formats = this.configs.stream().map(c -> c.v1().format()).collect(Collectors.toList());
120122
this.order = order;
121123
this.collectMode = collectMode;
@@ -138,14 +140,17 @@ protected Aggregator createInternal(
138140
// counting
139141
bucketCountThresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(bucketCountThresholds.getRequiredSize()));
140142
}
143+
// TODO: Optimize passing too many value source config derived objects to aggregator
141144
bucketCountThresholds.ensureValidity();
142145
return new MultiTermsAggregator(
143146
name,
144147
factories,
145148
showTermDocCountError,
149+
configs.stream().map(config -> config.v1().getValuesSource()).toList(),
146150
configs.stream()
147151
.map(config -> queryShardContext.getValuesSourceRegistry().getAggregator(REGISTRY_KEY, config.v1()).build(config))
148152
.collect(Collectors.toList()),
153+
this.getRequestFields(),
149154
configs.stream().map(c -> c.v1().format()).collect(Collectors.toList()),
150155
order,
151156
collectMode,
@@ -157,6 +162,10 @@ protected Aggregator createInternal(
157162
);
158163
}
159164

165+
public List<String> getRequestFields() {
166+
return requestFields;
167+
}
168+
160169
@Override
161170
protected boolean supportsConcurrentSegmentSearch() {
162171
return true;

server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java

Lines changed: 160 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@
2323
import org.opensearch.core.common.io.stream.StreamInput;
2424
import org.opensearch.core.common.io.stream.StreamOutput;
2525
import org.opensearch.core.common.io.stream.Writeable;
26+
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
27+
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
28+
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
29+
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.StarTreeValuesIterator;
2630
import org.opensearch.index.fielddata.SortedBinaryDocValues;
2731
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
32+
import org.opensearch.index.mapper.NumberFieldMapper;
2833
import org.opensearch.search.DocValueFormat;
2934
import org.opensearch.search.aggregations.Aggregator;
3035
import org.opensearch.search.aggregations.AggregatorFactories;
@@ -33,12 +38,17 @@
3338
import org.opensearch.search.aggregations.InternalAggregation;
3439
import org.opensearch.search.aggregations.InternalOrder;
3540
import org.opensearch.search.aggregations.LeafBucketCollector;
41+
import org.opensearch.search.aggregations.StarTreeBucketCollector;
42+
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
3643
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
3744
import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator;
3845
import org.opensearch.search.aggregations.bucket.LocalBucketCountThresholds;
3946
import org.opensearch.search.aggregations.support.AggregationPath;
4047
import org.opensearch.search.aggregations.support.ValuesSource;
4148
import org.opensearch.search.internal.SearchContext;
49+
import org.opensearch.search.startree.StarTreeQueryHelper;
50+
import org.opensearch.search.startree.filter.DimensionFilter;
51+
import org.opensearch.search.startree.filter.MatchAllFilter;
4252

4353
import java.io.IOException;
4454
import java.math.BigInteger;
@@ -50,32 +60,38 @@
5060
import java.util.List;
5161
import java.util.Map;
5262
import java.util.Set;
63+
import java.util.function.Function;
5364

5465
import static org.opensearch.search.aggregations.InternalOrder.isKeyOrder;
5566
import static org.opensearch.search.aggregations.bucket.terms.TermsAggregator.descendsFromNestedAggregator;
67+
import static org.opensearch.search.startree.StarTreeQueryHelper.getSupportedStarTree;
5668

5769
/**
5870
* An aggregator that aggregate with multi_terms.
5971
*
6072
* @opensearch.internal
6173
*/
62-
public class MultiTermsAggregator extends DeferableBucketAggregator {
74+
public class MultiTermsAggregator extends DeferableBucketAggregator implements StarTreePreComputeCollector {
6375

6476
private final BytesKeyedBucketOrds bucketOrds;
6577
private final MultiTermsValuesSource multiTermsValue;
6678
private final boolean showTermDocCountError;
6779
private final List<DocValueFormat> formats;
80+
private final List<String> fields;
6881
private final TermsAggregator.BucketCountThresholds bucketCountThresholds;
6982
private final BucketOrder order;
7083
private final Comparator<InternalMultiTerms.Bucket> partiallyBuiltBucketComparator;
7184
private final SubAggCollectionMode collectMode;
7285
private final Set<Aggregator> aggsUsedForSorting = new HashSet<>();
86+
private final BytesStreamOutput starTreeScratch = new BytesStreamOutput();
7387

7488
public MultiTermsAggregator(
7589
String name,
7690
AggregatorFactories factories,
7791
boolean showTermDocCountError,
92+
List<ValuesSource> rawValuesSources,
7893
List<InternalValuesSource> internalValuesSources,
94+
List<String> fields,
7995
List<DocValueFormat> formats,
8096
BucketOrder order,
8197
SubAggCollectionMode collectMode,
@@ -87,7 +103,7 @@ public MultiTermsAggregator(
87103
) throws IOException {
88104
super(name, factories, context, parent, metadata);
89105
this.bucketOrds = BytesKeyedBucketOrds.build(context.bigArrays(), cardinality);
90-
this.multiTermsValue = new MultiTermsValuesSource(internalValuesSources);
106+
this.multiTermsValue = new MultiTermsValuesSource(rawValuesSources, internalValuesSources);
91107
this.showTermDocCountError = showTermDocCountError;
92108
this.formats = formats;
93109
this.bucketCountThresholds = bucketCountThresholds;
@@ -104,12 +120,12 @@ public MultiTermsAggregator(
104120
} else {
105121
this.collectMode = collectMode;
106122
}
123+
this.fields = fields;
107124
// Don't defer any child agg if we are dependent on it for pruning results
108125
if (order instanceof InternalOrder.Aggregation) {
109126
AggregationPath path = ((InternalOrder.Aggregation) order).path();
110127
aggsUsedForSorting.add(path.resolveTopmostAggregator(this));
111-
} else if (order instanceof InternalOrder.CompoundOrder) {
112-
InternalOrder.CompoundOrder compoundOrder = (InternalOrder.CompoundOrder) order;
128+
} else if (order instanceof InternalOrder.CompoundOrder compoundOrder) {
113129
for (BucketOrder orderElement : compoundOrder.orderElements()) {
114130
if (orderElement instanceof InternalOrder.Aggregation) {
115131
AggregationPath path = ((InternalOrder.Aggregation) orderElement).path();
@@ -226,6 +242,142 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
226242
};
227243
}
228244

245+
@Override
246+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
247+
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
248+
if (supportedStarTree != null) {
249+
preComputeWithStarTree(ctx, supportedStarTree);
250+
return true;
251+
}
252+
return false;
253+
}
254+
255+
private void preComputeWithStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
256+
StarTreeBucketCollector starTreeBucketCollector = getStarTreeBucketCollector(ctx, starTree, null);
257+
StarTreeQueryHelper.preComputeBucketsWithStarTree(starTreeBucketCollector);
258+
}
259+
260+
/**
261+
* Creates a {@link StarTreeBucketCollector} for pre-aggregating with a star-tree index.
262+
* This collector generates the cartesian product of dimension values within a single star-tree entry
263+
* to form the composite keys for the multi-terms aggregation.
264+
*/
265+
public StarTreeBucketCollector getStarTreeBucketCollector(
266+
LeafReaderContext ctx,
267+
CompositeIndexFieldInfo starTree,
268+
StarTreeBucketCollector parent
269+
) throws IOException {
270+
StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
271+
assert starTreeValues != null;
272+
SortedNumericStarTreeValuesIterator docCountsIterator = StarTreeQueryHelper.getDocCountsIterator(starTreeValues, starTree);
273+
274+
// Get an iterator for each field (dimension) in the multi-terms aggregation.
275+
final List<StarTreeValuesIterator> dimensionIterators = new ArrayList<>();
276+
// We also need a way to convert the raw long values from the iterators into the correct TermValue type.
277+
final List<Function<Long, TermValue<?>>> termValueBuilders = new ArrayList<>();
278+
279+
for (int i = 0; i < fields.size(); i++) {
280+
String fieldName = fields.get(i);
281+
dimensionIterators.add(starTreeValues.getDimensionValuesIterator(fieldName));
282+
ValuesSource vs = multiTermsValue.rawValueSources.get(i);
283+
284+
if (vs instanceof ValuesSource.Bytes.WithOrdinals vsBytes) {
285+
termValueBuilders.add(ord -> {
286+
try {
287+
return TermValue.of(vsBytes.globalOrdinalsValues(ctx).lookupOrd(ord));
288+
} catch (IOException e) {
289+
throw new RuntimeException(e);
290+
}
291+
292+
});
293+
} else if (vs instanceof ValuesSource.Numeric numericSource) {
294+
if (numericSource.isFloatingPoint()) {
295+
NumberFieldMapper.NumberFieldType numberFieldType = ((NumberFieldMapper.NumberFieldType) context.mapperService()
296+
.fieldType(fieldName));
297+
termValueBuilders.add(val -> TermValue.of(numberFieldType.toDoubleValue(val)));
298+
} else {
299+
termValueBuilders.add(TermValue::of);
300+
}
301+
} else {
302+
throw new IllegalStateException("Unsupported ValuesSource type for star-tree: " + vs.getClass().getName());
303+
}
304+
305+
}
306+
307+
return new StarTreeBucketCollector(
308+
starTreeValues,
309+
parent == null ? StarTreeQueryHelper.getStarTreeResult(starTreeValues, context, getDimensionFilters()) : null
310+
) {
311+
@Override
312+
public void setSubCollectors() throws IOException {
313+
for (Aggregator aggregator : subAggregators) {
314+
if (aggregator instanceof StarTreePreComputeCollector collector) {
315+
this.subCollectors.add(collector.getStarTreeBucketCollector(ctx, starTree, this));
316+
}
317+
}
318+
}
319+
320+
@Override
321+
public void collectStarTreeEntry(int starTreeEntry, long owningBucketOrd) throws IOException {
322+
if (docCountsIterator.advanceExact(starTreeEntry) == false) {
323+
return; // No documents in this star-tree entry.
324+
}
325+
long docCountMetric = docCountsIterator.nextValue();
326+
327+
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
328+
for (int i = 0; i < dimensionIterators.size(); i++) {
329+
StarTreeValuesIterator dimIterator = dimensionIterators.get(i);
330+
if (!dimIterator.advanceExact(starTreeEntry)) {
331+
// If any dimension is missing for this entry, the cartesian product is empty.
332+
return;
333+
}
334+
335+
List<TermValue<?>> valuesForDim = new ArrayList<>();
336+
Function<Long, TermValue<?>> builder = termValueBuilders.get(i);
337+
for (int j = 0; j < dimIterator.entryValueCount(); j++) {
338+
valuesForDim.add(builder.apply(dimIterator.value()));
339+
}
340+
collectedValues.add(valuesForDim);
341+
}
342+
343+
starTreeScratch.seek(0);
344+
starTreeScratch.writeVInt(dimensionIterators.size());
345+
generateAndCollectFromStarTree(collectedValues, 0, owningBucketOrd, starTreeEntry, docCountMetric);
346+
}
347+
348+
private void generateAndCollectFromStarTree(
349+
List<List<TermValue<?>>> collectedValues,
350+
int index,
351+
long owningBucketOrd,
352+
int starTreeEntry,
353+
long docCountMetric
354+
) throws IOException {
355+
if (index == collectedValues.size()) {
356+
// A full composite key is in the buffer, add it to bucketOrds.
357+
long bucketOrd = bucketOrds.add(owningBucketOrd, starTreeScratch.bytes().toBytesRef());
358+
collectStarTreeBucket(this, docCountMetric, bucketOrd, starTreeEntry);
359+
return;
360+
}
361+
362+
long position = starTreeScratch.position();
363+
List<TermValue<?>> values = collectedValues.get(index);
364+
for (TermValue<?> value : values) {
365+
value.writeTo(starTreeScratch);
366+
generateAndCollectFromStarTree(collectedValues, index + 1, owningBucketOrd, starTreeEntry, docCountMetric);
367+
starTreeScratch.seek(position);
368+
}
369+
}
370+
};
371+
}
372+
373+
@Override
374+
public List<DimensionFilter> getDimensionFilters() {
375+
return StarTreeQueryHelper.collectDimensionFilters(
376+
fields.stream().map(a -> (DimensionFilter) new MatchAllFilter(a)).toList(),
377+
subAggregators
378+
);
379+
}
380+
229381
@Override
230382
protected void doClose() {
231383
Releasables.close(bucketOrds, multiTermsValue);
@@ -347,10 +499,12 @@ public static TermValue<Double> of(Double value) {
347499
* @opensearch.internal
348500
*/
349501
static class MultiTermsValuesSource implements Releasable {
502+
private final List<ValuesSource> rawValueSources;
350503
private final List<InternalValuesSource> valuesSources;
351504
private final BytesStreamOutput scratch = new BytesStreamOutput();
352505

353-
public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) {
506+
public MultiTermsValuesSource(List<ValuesSource> rawValueSources, List<InternalValuesSource> valuesSources) {
507+
this.rawValueSources = rawValueSources;
354508
this.valuesSources = valuesSources;
355509
}
356510

@@ -416,8 +570,7 @@ private void generateAndCollectCompositeKeys(
416570
int numIterations = values.size();
417571
// For each loop is not done to reduce the allocations done for Iterator objects
418572
// once for every field in every doc.
419-
for (int i = 0; i < numIterations; i++) {
420-
TermValue<?> value = values.get(i);
573+
for (TermValue<?> value : values) {
421574
value.writeTo(scratch); // encode the value
422575
generateAndCollectCompositeKeys(collectedValues, index + 1, owningBucketOrd, doc); // dfs
423576
scratch.seek(position); // backtrack

0 commit comments

Comments
 (0)