2323import org .opensearch .core .common .io .stream .StreamInput ;
2424import org .opensearch .core .common .io .stream .StreamOutput ;
2525import org .opensearch .core .common .io .stream .Writeable ;
26+ import org .opensearch .index .codec .composite .CompositeIndexFieldInfo ;
27+ import org .opensearch .index .compositeindex .datacube .startree .index .StarTreeValues ;
28+ import org .opensearch .index .compositeindex .datacube .startree .utils .iterator .SortedNumericStarTreeValuesIterator ;
29+ import org .opensearch .index .compositeindex .datacube .startree .utils .iterator .StarTreeValuesIterator ;
2630import org .opensearch .index .fielddata .SortedBinaryDocValues ;
2731import org .opensearch .index .fielddata .SortedNumericDoubleValues ;
32+ import org .opensearch .index .mapper .NumberFieldMapper ;
2833import org .opensearch .search .DocValueFormat ;
2934import org .opensearch .search .aggregations .Aggregator ;
3035import org .opensearch .search .aggregations .AggregatorFactories ;
3338import org .opensearch .search .aggregations .InternalAggregation ;
3439import org .opensearch .search .aggregations .InternalOrder ;
3540import org .opensearch .search .aggregations .LeafBucketCollector ;
41+ import org .opensearch .search .aggregations .StarTreeBucketCollector ;
42+ import org .opensearch .search .aggregations .StarTreePreComputeCollector ;
3643import org .opensearch .search .aggregations .bucket .BucketsAggregator ;
3744import org .opensearch .search .aggregations .bucket .DeferableBucketAggregator ;
3845import org .opensearch .search .aggregations .bucket .LocalBucketCountThresholds ;
3946import org .opensearch .search .aggregations .support .AggregationPath ;
4047import org .opensearch .search .aggregations .support .ValuesSource ;
4148import org .opensearch .search .internal .SearchContext ;
49+ import org .opensearch .search .startree .StarTreeQueryHelper ;
50+ import org .opensearch .search .startree .filter .DimensionFilter ;
51+ import org .opensearch .search .startree .filter .MatchAllFilter ;
4252
4353import java .io .IOException ;
4454import java .math .BigInteger ;
5060import java .util .List ;
5161import java .util .Map ;
5262import java .util .Set ;
63+ import java .util .function .Function ;
5364
5465import static org .opensearch .search .aggregations .InternalOrder .isKeyOrder ;
5566import static org .opensearch .search .aggregations .bucket .terms .TermsAggregator .descendsFromNestedAggregator ;
67+ import static org .opensearch .search .startree .StarTreeQueryHelper .getSupportedStarTree ;
5668
5769/**
5870 * An aggregator that aggregate with multi_terms.
5971 *
6072 * @opensearch.internal
6173 */
62- public class MultiTermsAggregator extends DeferableBucketAggregator {
74+ public class MultiTermsAggregator extends DeferableBucketAggregator implements StarTreePreComputeCollector {
6375
6476 private final BytesKeyedBucketOrds bucketOrds ;
6577 private final MultiTermsValuesSource multiTermsValue ;
6678 private final boolean showTermDocCountError ;
6779 private final List <DocValueFormat > formats ;
80+ private final List <String > fields ;
6881 private final TermsAggregator .BucketCountThresholds bucketCountThresholds ;
6982 private final BucketOrder order ;
7083 private final Comparator <InternalMultiTerms .Bucket > partiallyBuiltBucketComparator ;
7184 private final SubAggCollectionMode collectMode ;
7285 private final Set <Aggregator > aggsUsedForSorting = new HashSet <>();
86+ private final BytesStreamOutput starTreeScratch = new BytesStreamOutput ();
7387
7488 public MultiTermsAggregator (
7589 String name ,
7690 AggregatorFactories factories ,
7791 boolean showTermDocCountError ,
92+ List <ValuesSource > rawValuesSources ,
7893 List <InternalValuesSource > internalValuesSources ,
94+ List <String > fields ,
7995 List <DocValueFormat > formats ,
8096 BucketOrder order ,
8197 SubAggCollectionMode collectMode ,
@@ -87,7 +103,7 @@ public MultiTermsAggregator(
87103 ) throws IOException {
88104 super (name , factories , context , parent , metadata );
89105 this .bucketOrds = BytesKeyedBucketOrds .build (context .bigArrays (), cardinality );
90- this .multiTermsValue = new MultiTermsValuesSource (internalValuesSources );
106+ this .multiTermsValue = new MultiTermsValuesSource (rawValuesSources , internalValuesSources );
91107 this .showTermDocCountError = showTermDocCountError ;
92108 this .formats = formats ;
93109 this .bucketCountThresholds = bucketCountThresholds ;
@@ -104,12 +120,12 @@ public MultiTermsAggregator(
104120 } else {
105121 this .collectMode = collectMode ;
106122 }
123+ this .fields = fields ;
107124 // Don't defer any child agg if we are dependent on it for pruning results
108125 if (order instanceof InternalOrder .Aggregation ) {
109126 AggregationPath path = ((InternalOrder .Aggregation ) order ).path ();
110127 aggsUsedForSorting .add (path .resolveTopmostAggregator (this ));
111- } else if (order instanceof InternalOrder .CompoundOrder ) {
112- InternalOrder .CompoundOrder compoundOrder = (InternalOrder .CompoundOrder ) order ;
128+ } else if (order instanceof InternalOrder .CompoundOrder compoundOrder ) {
113129 for (BucketOrder orderElement : compoundOrder .orderElements ()) {
114130 if (orderElement instanceof InternalOrder .Aggregation ) {
115131 AggregationPath path = ((InternalOrder .Aggregation ) orderElement ).path ();
@@ -226,6 +242,142 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
226242 };
227243 }
228244
245+ @ Override
246+ protected boolean tryPrecomputeAggregationForLeaf (LeafReaderContext ctx ) throws IOException {
247+ CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree (this .context .getQueryShardContext ());
248+ if (supportedStarTree != null ) {
249+ preComputeWithStarTree (ctx , supportedStarTree );
250+ return true ;
251+ }
252+ return false ;
253+ }
254+
255+ private void preComputeWithStarTree (LeafReaderContext ctx , CompositeIndexFieldInfo starTree ) throws IOException {
256+ StarTreeBucketCollector starTreeBucketCollector = getStarTreeBucketCollector (ctx , starTree , null );
257+ StarTreeQueryHelper .preComputeBucketsWithStarTree (starTreeBucketCollector );
258+ }
259+
260+ /**
261+ * Creates a {@link StarTreeBucketCollector} for pre-aggregating with a star-tree index.
262+ * This collector generates the cartesian product of dimension values within a single star-tree entry
263+ * to form the composite keys for the multi-terms aggregation.
264+ */
265+ public StarTreeBucketCollector getStarTreeBucketCollector (
266+ LeafReaderContext ctx ,
267+ CompositeIndexFieldInfo starTree ,
268+ StarTreeBucketCollector parent
269+ ) throws IOException {
270+ StarTreeValues starTreeValues = StarTreeQueryHelper .getStarTreeValues (ctx , starTree );
271+ assert starTreeValues != null ;
272+ SortedNumericStarTreeValuesIterator docCountsIterator = StarTreeQueryHelper .getDocCountsIterator (starTreeValues , starTree );
273+
274+ // Get an iterator for each field (dimension) in the multi-terms aggregation.
275+ final List <StarTreeValuesIterator > dimensionIterators = new ArrayList <>();
276+ // We also need a way to convert the raw long values from the iterators into the correct TermValue type.
277+ final List <Function <Long , TermValue <?>>> termValueBuilders = new ArrayList <>();
278+
279+ for (int i = 0 ; i < fields .size (); i ++) {
280+ String fieldName = fields .get (i );
281+ dimensionIterators .add (starTreeValues .getDimensionValuesIterator (fieldName ));
282+ ValuesSource vs = multiTermsValue .rawValueSources .get (i );
283+
284+ if (vs instanceof ValuesSource .Bytes .WithOrdinals vsBytes ) {
285+ termValueBuilders .add (ord -> {
286+ try {
287+ return TermValue .of (vsBytes .globalOrdinalsValues (ctx ).lookupOrd (ord ));
288+ } catch (IOException e ) {
289+ throw new RuntimeException (e );
290+ }
291+
292+ });
293+ } else if (vs instanceof ValuesSource .Numeric numericSource ) {
294+ if (numericSource .isFloatingPoint ()) {
295+ NumberFieldMapper .NumberFieldType numberFieldType = ((NumberFieldMapper .NumberFieldType ) context .mapperService ()
296+ .fieldType (fieldName ));
297+ termValueBuilders .add (val -> TermValue .of (numberFieldType .toDoubleValue (val )));
298+ } else {
299+ termValueBuilders .add (TermValue ::of );
300+ }
301+ } else {
302+ throw new IllegalStateException ("Unsupported ValuesSource type for star-tree: " + vs .getClass ().getName ());
303+ }
304+
305+ }
306+
307+ return new StarTreeBucketCollector (
308+ starTreeValues ,
309+ parent == null ? StarTreeQueryHelper .getStarTreeResult (starTreeValues , context , getDimensionFilters ()) : null
310+ ) {
311+ @ Override
312+ public void setSubCollectors () throws IOException {
313+ for (Aggregator aggregator : subAggregators ) {
314+ if (aggregator instanceof StarTreePreComputeCollector collector ) {
315+ this .subCollectors .add (collector .getStarTreeBucketCollector (ctx , starTree , this ));
316+ }
317+ }
318+ }
319+
320+ @ Override
321+ public void collectStarTreeEntry (int starTreeEntry , long owningBucketOrd ) throws IOException {
322+ if (docCountsIterator .advanceExact (starTreeEntry ) == false ) {
323+ return ; // No documents in this star-tree entry.
324+ }
325+ long docCountMetric = docCountsIterator .nextValue ();
326+
327+ List <List <TermValue <?>>> collectedValues = new ArrayList <>();
328+ for (int i = 0 ; i < dimensionIterators .size (); i ++) {
329+ StarTreeValuesIterator dimIterator = dimensionIterators .get (i );
330+ if (!dimIterator .advanceExact (starTreeEntry )) {
331+ // If any dimension is missing for this entry, the cartesian product is empty.
332+ return ;
333+ }
334+
335+ List <TermValue <?>> valuesForDim = new ArrayList <>();
336+ Function <Long , TermValue <?>> builder = termValueBuilders .get (i );
337+ for (int j = 0 ; j < dimIterator .entryValueCount (); j ++) {
338+ valuesForDim .add (builder .apply (dimIterator .value ()));
339+ }
340+ collectedValues .add (valuesForDim );
341+ }
342+
343+ starTreeScratch .seek (0 );
344+ starTreeScratch .writeVInt (dimensionIterators .size ());
345+ generateAndCollectFromStarTree (collectedValues , 0 , owningBucketOrd , starTreeEntry , docCountMetric );
346+ }
347+
348+ private void generateAndCollectFromStarTree (
349+ List <List <TermValue <?>>> collectedValues ,
350+ int index ,
351+ long owningBucketOrd ,
352+ int starTreeEntry ,
353+ long docCountMetric
354+ ) throws IOException {
355+ if (index == collectedValues .size ()) {
356+ // A full composite key is in the buffer, add it to bucketOrds.
357+ long bucketOrd = bucketOrds .add (owningBucketOrd , starTreeScratch .bytes ().toBytesRef ());
358+ collectStarTreeBucket (this , docCountMetric , bucketOrd , starTreeEntry );
359+ return ;
360+ }
361+
362+ long position = starTreeScratch .position ();
363+ List <TermValue <?>> values = collectedValues .get (index );
364+ for (TermValue <?> value : values ) {
365+ value .writeTo (starTreeScratch );
366+ generateAndCollectFromStarTree (collectedValues , index + 1 , owningBucketOrd , starTreeEntry , docCountMetric );
367+ starTreeScratch .seek (position );
368+ }
369+ }
370+ };
371+ }
372+
373+ @ Override
374+ public List <DimensionFilter > getDimensionFilters () {
375+ return StarTreeQueryHelper .collectDimensionFilters (
376+ fields .stream ().map (a -> (DimensionFilter ) new MatchAllFilter (a )).toList (),
377+ subAggregators
378+ );
379+ }
380+
229381 @ Override
230382 protected void doClose () {
231383 Releasables .close (bucketOrds , multiTermsValue );
@@ -347,10 +499,12 @@ public static TermValue<Double> of(Double value) {
347499 * @opensearch.internal
348500 */
349501 static class MultiTermsValuesSource implements Releasable {
502+ private final List <ValuesSource > rawValueSources ;
350503 private final List <InternalValuesSource > valuesSources ;
351504 private final BytesStreamOutput scratch = new BytesStreamOutput ();
352505
353- public MultiTermsValuesSource (List <InternalValuesSource > valuesSources ) {
506+ public MultiTermsValuesSource (List <ValuesSource > rawValueSources , List <InternalValuesSource > valuesSources ) {
507+ this .rawValueSources = rawValueSources ;
354508 this .valuesSources = valuesSources ;
355509 }
356510
@@ -416,8 +570,7 @@ private void generateAndCollectCompositeKeys(
416570 int numIterations = values .size ();
417571 // For each loop is not done to reduce the allocations done for Iterator objects
418572 // once for every field in every doc.
419- for (int i = 0 ; i < numIterations ; i ++) {
420- TermValue <?> value = values .get (i );
573+ for (TermValue <?> value : values ) {
421574 value .writeTo (scratch ); // encode the value
422575 generateAndCollectCompositeKeys (collectedValues , index + 1 , owningBucketOrd , doc ); // dfs
423576 scratch .seek (position ); // backtrack
0 commit comments