diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c3b10defac77..723bbcc0002da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add a dynamic setting to change skip_cache_factor and min_frequency for querycache ([#18351](https://github.com/opensearch-project/OpenSearch/issues/18351)) - Add overload constructor for Translog to accept Channel Factory as a parameter ([#18918](https://github.com/opensearch-project/OpenSearch/pull/18918)) - Add subdirectory-aware store module with recovery support ([#19132](https://github.com/opensearch-project/OpenSearch/pull/19132)) +- [Rule-based Auto-tagging] Add autotagging label resolving logic for multiple attributes ([#19424](https://github.com/opensearch-project/OpenSearch/pull/19424)) - Field collapsing supports search_after ([#19261](https://github.com/opensearch-project/OpenSearch/pull/19261)) - Add a dynamic cluster setting to control the enablement of the merged segment warmer ([#18929](https://github.com/opensearch-project/OpenSearch/pull/18929)) - Publish transport-grpc-spi exposing QueryBuilderProtoConverter and QueryBuilderProtoConverterRegistry ([#18949](https://github.com/opensearch-project/OpenSearch/pull/18949)) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4bfed7ad4b27c..0695a0a9515a1 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -56,6 +56,7 @@ commonscodec = "1.18.0" commonslang = "3.18.0" commonscompress = "1.28.0" commonsio = "2.16.0" +commonscollections4 = "4.5.0" # plugin dependencies aws = "2.32.29" awscrt = "0.35.0" diff --git a/modules/autotagging-commons/common/build.gradle b/modules/autotagging-commons/common/build.gradle index 0dffb80015647..6b851d1974b4c 100644 --- a/modules/autotagging-commons/common/build.gradle +++ b/modules/autotagging-commons/common/build.gradle @@ -12,7 +12,7 @@ apply plugin: 'opensearch.publish' description = 'OpenSearch Rule framework common constructs which spi and module shares' dependencies { - api 'org.apache.commons:commons-collections4:4.4' + api "org.apache.commons:commons-collections4:${versions.commonscollections4}" implementation project(":libs:opensearch-common") compileOnly project(":server") diff --git a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/attribute_extractor/AttributeExtractor.java b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/attribute_extractor/AttributeExtractor.java index 186211c65a76e..0bfc9e60ed8b1 100644 --- a/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/attribute_extractor/AttributeExtractor.java +++ b/modules/autotagging-commons/common/src/main/java/org/opensearch/rule/attribute_extractor/AttributeExtractor.java @@ -15,6 +15,22 @@ * @param */ public interface AttributeExtractor { + + /** + * Defines the combination style used when a request contains multiple values + * for an attribute. + */ + enum LogicalOperator { + /** + * Logical AND + */ + AND, + /** + * Logical OR + */ + OR + } + /** * This method returns the Attribute which it is responsible for extracting * @return attribute @@ -26,4 +42,13 @@ public interface AttributeExtractor { * @return attribute value */ Iterable extract(); + + /** + * Returns the logical operator used when a request contains multiple values + * for an attribute. + * For example, if the request targets both index A and B, then a rule must + * have both index A and B as attributes, requiring an AND operator. + * @return the logical operator (e.g., AND, OR) + */ + LogicalOperator getLogicalOperator(); } diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/InMemoryRuleProcessingService.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/InMemoryRuleProcessingService.java index 7cf8b3bf8daec..5397f9f12915a 100644 --- a/modules/autotagging-commons/src/main/java/org/opensearch/rule/InMemoryRuleProcessingService.java +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/InMemoryRuleProcessingService.java @@ -11,9 +11,11 @@ import org.opensearch.rule.attribute_extractor.AttributeExtractor; import org.opensearch.rule.autotagging.Attribute; import org.opensearch.rule.autotagging.Rule; +import org.opensearch.rule.feature_value_resolver.FeatureValueResolver; import org.opensearch.rule.storage.AttributeValueStore; import org.opensearch.rule.storage.AttributeValueStoreFactory; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -32,13 +34,23 @@ public class InMemoryRuleProcessingService { */ public static final String WILDCARD = "*"; private final AttributeValueStoreFactory attributeValueStoreFactory; + /** + * Map of prioritized attributes + */ + private final Map prioritizedAttributes; /** - * Constructor - * @param attributeValueStoreFactory + * Constructs an InMemoryRuleProcessingService with the given + * attribute value store factory and a prioritized list of attributes. + * @param attributeValueStoreFactory Factory to create attribute value stores. + * @param prioritizedAttributes Map of prioritized attributes */ - public InMemoryRuleProcessingService(AttributeValueStoreFactory attributeValueStoreFactory) { + public InMemoryRuleProcessingService( + AttributeValueStoreFactory attributeValueStoreFactory, + Map prioritizedAttributes + ) { this.attributeValueStoreFactory = attributeValueStoreFactory; + this.prioritizedAttributes = prioritizedAttributes; } /** @@ -58,8 +70,14 @@ public void remove(final Rule rule) { } private void perform(Rule rule, BiConsumer>, Rule> ruleOperation) { - for (Map.Entry> attributeEntry : rule.getAttributeMap().entrySet()) { - ruleOperation.accept(attributeEntry, rule); + for (Attribute attribute : rule.getFeatureType().getAllowedAttributesRegistry().values()) { + Set attributeValues; + if (rule.getAttributeMap().containsKey(attribute)) { + attributeValues = rule.getAttributeMap().get(attribute); + } else { + attributeValues = Set.of(""); + } + ruleOperation.accept(Map.entry(attribute, attributeValues), rule); } } @@ -78,37 +96,14 @@ private void addOperation(Map.Entry> attributeEntry, Rule } /** - * Evaluates the label for the current request. It finds the matches for each attribute value and then it is an - * intersection of all the matches - * @param attributeExtractors list of extractors which are used to get the attribute values to find the - * matching rule - * @return a label if there is unique label otherwise empty + * Determines the final feature value for the given request + * @param attributeExtractors list of attribute extractors */ public Optional evaluateLabel(List> attributeExtractors) { - assert attributeValueStoreFactory != null; - Optional result = Optional.empty(); - for (AttributeExtractor attributeExtractor : attributeExtractors) { - AttributeValueStore valueStore = attributeValueStoreFactory.getAttributeValueStore( - attributeExtractor.getAttribute() - ); - for (String value : attributeExtractor.extract()) { - List> candidateMatches = valueStore.getAll(value); - - if (candidateMatches == null || candidateMatches.isEmpty()) { - return Optional.empty(); - } - - Optional possibleMatch = candidateMatches.get(0).stream().findAny(); - if (result.isEmpty()) { - result = possibleMatch; - } else { - boolean isThePossibleMatchEqualResult = possibleMatch.get().equals(result.get()); - if (!isThePossibleMatchEqualResult) { - return Optional.empty(); - } - } - } - } - return result; + attributeExtractors.sort( + Comparator.comparingInt(extractor -> prioritizedAttributes.getOrDefault(extractor.getAttribute(), Integer.MAX_VALUE)) + ); + FeatureValueResolver featureValueResolver = new FeatureValueResolver(attributeValueStoreFactory); + return featureValueResolver.resolve(attributeExtractors).resolveLabel(); } } diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/CandidateFeatureValues.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/CandidateFeatureValues.java new file mode 100644 index 0000000000000..753579251d5fa --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/CandidateFeatureValues.java @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.feature_value_resolver; + +import org.opensearch.rule.attribute_extractor.AttributeExtractor; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Represents candidate feature values for an attribute + */ +public class CandidateFeatureValues { + + /** + * A list of sets of candidate feature values collected for an attribute + * The list is ordered from the most specific match to less specific ones. For example: + * featureValues = [ {"a", "b"}, {"c"} ] + * Here, {"a", "b"} comes first because these feature values comes from rules with a more specific match + * e.g. A rule with "username|123" is a more specific match than "username|1" when querying "username|1234". + */ + private final List> featureValuesBySpecificity; + + /** + * A flattened set of all candidate values collected across all specificity levels. + * This set combines all values in 'featureValues' into a single collection for easy access + * and intersection computations. + */ + private final Set flattenedValues = new HashSet<>(); + + /** + * Maps each feature value to the index of its first occurrence set in 'featureValues'. + * This helps in tie-breaking: values appearing earlier in the list (i.e., more specific matches) + * are considered better matches when resolving the final label. + */ + private final Map firstOccurrenceIndex = new HashMap<>(); + + /** + * Constructs CandidateFeatureValues initialized with given list of value sets. + * @param initialValues List of sets of candidate values. + */ + public CandidateFeatureValues(List> initialValues) { + this.featureValuesBySpecificity = new ArrayList<>(initialValues); + for (int i = 0; i < featureValuesBySpecificity.size(); i++) { + for (String val : featureValuesBySpecificity.get(i)) { + flattenedValues.add(val); + firstOccurrenceIndex.putIfAbsent(val, i); + } + } + } + + /** + * flattenedValues getter + */ + public Set getFlattenedValues() { + return flattenedValues; + } + + /** + * firstOccurrenceIndex getter + * @param value + */ + public int getFirstOccurrenceIndex(String value) { + return firstOccurrenceIndex.getOrDefault(value, Integer.MAX_VALUE); + } + + /** + * Merges this CandidateFeatureValues with another based on the specified logical operator + * @param other Other CandidateFeatureValues to merge with. + * @param logicalOperator Logical operator (AND / OR) for merging. + */ + public CandidateFeatureValues merge(CandidateFeatureValues other, AttributeExtractor.LogicalOperator logicalOperator) { + return switch (logicalOperator) { + case AND -> mergeAnd(other); + case OR -> mergeOr(other); + }; + } + + private CandidateFeatureValues mergeOr(CandidateFeatureValues other) { + return mergeByIndex(this.featureValuesBySpecificity, other.featureValuesBySpecificity, null); + } + + private CandidateFeatureValues mergeAnd(CandidateFeatureValues other) { + Set elementsInThis = this.featureValuesBySpecificity.stream().flatMap(Set::stream).collect(Collectors.toSet()); + Set elementsInOther = other.featureValuesBySpecificity.stream().flatMap(Set::stream).collect(Collectors.toSet()); + + Set common = new HashSet<>(elementsInThis); + common.retainAll(elementsInOther); + + return mergeByIndex(this.featureValuesBySpecificity, other.featureValuesBySpecificity, common); + } + + private CandidateFeatureValues mergeByIndex(List> list1, List> list2, Set filterElements) { + List> result = new ArrayList<>(); + int max = Math.max(list1.size(), list2.size()); + + for (int i = 0; i < max; i++) { + Set merged = new HashSet<>(); + if (i < list1.size()) { + merged.addAll(list1.get(i)); + } + if (i < list2.size()) { + merged.addAll(list2.get(i)); + } + if (filterElements != null) { + merged.retainAll(filterElements); + } + if (!merged.isEmpty()) { + result.add(merged); + } + } + return new CandidateFeatureValues(result); + } + + @Override + public String toString() { + return "(" + "values=" + featureValuesBySpecificity + ')'; + } + + List> getFeatureValuesBySpecificity() { + return featureValuesBySpecificity; + } +} diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/FeatureValueCollector.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/FeatureValueCollector.java new file mode 100644 index 0000000000000..4b7e165cc62b4 --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/FeatureValueCollector.java @@ -0,0 +1,67 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.feature_value_resolver; + +import org.opensearch.rule.attribute_extractor.AttributeExtractor; +import org.opensearch.rule.storage.AttributeValueStore; + +import java.util.List; +import java.util.Set; + +/** + * Collects candidate feature values for a specified subfield of a given attribute extractor. + * For example, the "principal" attribute may contain subfields such as "username" and "role": + * principal: { + * "username": ["alice", "bob"], + * "role": ["admin"] + * } + * If the attribute does not define any subfields, then the subfield name is represented + * by an empty string "" + */ +public class FeatureValueCollector { + + private final AttributeValueStore attributeValueStore; + private final AttributeExtractor attributeExtractor; + private final String subfield; + + /** + * Constructs a FeatureValueCollector with the given store, extractor, and subfield. + * @param attributeValueStore The store to retrieve candidate feature values from. + * @param attributeExtractor The extractor to extract attribute values. + * @param subfield The subfield attribute + */ + public FeatureValueCollector( + AttributeValueStore attributeValueStore, + AttributeExtractor attributeExtractor, + String subfield + ) { + this.attributeValueStore = attributeValueStore; + this.attributeExtractor = attributeExtractor; + this.subfield = subfield; + } + + /** + * Collects feature values for the subfield from the attribute extractor. + */ + public CandidateFeatureValues collect() { + CandidateFeatureValues result = null; + for (String value : attributeExtractor.extract()) { + if (value.startsWith(subfield)) { + List> candidateLabels = attributeValueStore.getAll(value); + CandidateFeatureValues candidateValues = new CandidateFeatureValues(candidateLabels); + if (result == null) { + result = candidateValues; + } else { + result = candidateValues.merge(result, attributeExtractor.getLogicalOperator()); + } + } + } + return result; + } +} diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/FeatureValueResolver.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/FeatureValueResolver.java new file mode 100644 index 0000000000000..c7f4fd47aea73 --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/FeatureValueResolver.java @@ -0,0 +1,167 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.feature_value_resolver; + +import org.opensearch.rule.attribute_extractor.AttributeExtractor; +import org.opensearch.rule.autotagging.Attribute; +import org.opensearch.rule.storage.AttributeValueStore; +import org.opensearch.rule.storage.AttributeValueStoreFactory; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; + +/** + * This class is responsible for collecting candidate feature values + * from multiple {@link AttributeExtractor}s and determining the final feature value + * by computing the intersection of candidate feature values across all extractors. + * The workflow is as follows: + * Each AttributeExtractor is used to fetch candidate feature values for its attribute. + * Candidate feature values are collected for all extractors. + * An intersection of all candidate feature values is computed. + * The intersection is then reduced to a final feature value using tie-breaking logic. + */ +public class FeatureValueResolver { + private final AttributeValueStoreFactory storeFactory; + + /** + * Constructor for FeatureValueAggregator + * @param storeFactory + */ + public FeatureValueResolver(AttributeValueStoreFactory storeFactory) { + this.storeFactory = storeFactory; + } + + /** + * Key entry function for the class. + * This function collects candidate feature values from the given list of attribute extractors, + * returning the FeatureValueResolutionResult including all candidate values and + * their intersection. + * @param extractors list of attribute extractors to collect values from + */ + public FeatureValueResolutionResult resolve(List> extractors) { + List candidateFeatureValueList = new ArrayList<>(); + Set intersection = null; + + for (AttributeExtractor extractor : extractors) { + Set values = collectValuesForAttribute(extractor, candidateFeatureValueList); + + if (intersection == null) { + intersection = values; + } else { + intersection.retainAll(values); + } + if (intersection.isEmpty()) { + break; + } + } + + return new FeatureValueResolutionResult(candidateFeatureValueList, intersection); + } + + /** + * Collects candidate feature values for a single attribute extractor. + * @param extractor The attribute extractor to collect values for. + * @param candidateFeatureValueList List to which candidate values are added. + */ + private Set collectValuesForAttribute( + AttributeExtractor extractor, + List candidateFeatureValueList + ) { + Attribute attr = extractor.getAttribute(); + AttributeValueStore store = storeFactory.getAttributeValueStore(attr); + TreeMap subfields = attr.getPrioritizedSubfields(); + if (subfields.isEmpty()) { + subfields = new TreeMap<>(Map.of(1, "")); + } + + // Iterate through all the subfield attributes of the attribute extractor, and take the union of the collected + // feature values for each subfield attribute because the relationship between subfields is "OR". + // e.g. An request comes from username_a, who has role_b. Let's say rule A matches the request because it has + // attribute username_a, and Rule B matches because it has attribute role_b, then both rule A and rule B are + // qualified (OR relationship between feature values from different attribute) + Set res = new HashSet<>(); + for (Map.Entry subfield : subfields.entrySet()) { + FeatureValueCollector featureValueCollector = new FeatureValueCollector(store, extractor, subfield.getValue()); + CandidateFeatureValues valuesForSubfieldAttribute = featureValueCollector.collect(); + if (valuesForSubfieldAttribute != null) { + candidateFeatureValueList.add(valuesForSubfieldAttribute); + res.addAll(valuesForSubfieldAttribute.getFlattenedValues()); + } + } + return res; + } + + /** + * Encapsulates the result of feature value aggregation, including + * all candidate feature values and their intersection. + */ + public static class FeatureValueResolutionResult { + private final List candidateFeatureValuesList; + private final Set intersectedFeatureValues; + + /** + * Constructs an FeatureValueResolutionResult. + * @param candidateFeatureValuesList List of all candidate feature values collected. + * @param intersectedFeatureValues Set of values that are common to all candidates (intersection). + */ + public FeatureValueResolutionResult(List candidateFeatureValuesList, Set intersectedFeatureValues) { + this.candidateFeatureValuesList = candidateFeatureValuesList; + this.intersectedFeatureValues = intersectedFeatureValues; + } + + /** + * Resolves the final label (feature value), or empty if no label can be determined. + */ + public Optional resolveLabel() { + if (intersectedFeatureValues == null || intersectedFeatureValues.isEmpty()) { + return Optional.empty(); + } + if (intersectedFeatureValues.size() == 1) { + String res = intersectedFeatureValues.iterator().next(); + return Optional.of(res); + } + return breakTie(); + } + + /** + * Breaks ties among multiple candidate labels by examining the priorities and + * positions in the CandidateFeatureValues list. + */ + private Optional breakTie() { + Set remaining = new HashSet<>(intersectedFeatureValues); + for (CandidateFeatureValues values : candidateFeatureValuesList) { + String best = null; + int bestIndex = Integer.MAX_VALUE; + Set tied = new HashSet<>(); + for (String val : remaining) { + int index = values.getFirstOccurrenceIndex(val); + if (index < bestIndex) { + tied.clear(); + tied.add(val); + best = val; + bestIndex = index; + } else if (index == bestIndex) { + tied.add(val); + } + } + if (tied.size() == 1 && best != null) { + return Optional.of(best); + } + remaining = tied; + } + + return Optional.empty(); + } + } +} diff --git a/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/package-info.java b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/package-info.java new file mode 100644 index 0000000000000..3131accfe967f --- /dev/null +++ b/modules/autotagging-commons/src/main/java/org/opensearch/rule/feature_value_resolver/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains classes to resolve feature value + */ +package org.opensearch.rule.feature_value_resolver; diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java index 2a802e63a01c2..e326bc65a26b8 100644 --- a/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/InMemoryRuleProcessingServiceTests.java @@ -17,11 +17,14 @@ import org.opensearch.rule.storage.DefaultAttributeValueStore; import org.opensearch.test.OpenSearchTestCase; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import static org.opensearch.rule.attribute_extractor.AttributeExtractor.LogicalOperator.OR; + public class InMemoryRuleProcessingServiceTests extends OpenSearchTestCase { InMemoryRuleProcessingService sut; @@ -31,7 +34,7 @@ public void setUp() throws Exception { WLMFeatureType.WLM, DefaultAttributeValueStore::new ); - sut = new InMemoryRuleProcessingService(attributeValueStoreFactory); + sut = new InMemoryRuleProcessingService(attributeValueStoreFactory, WLMFeatureType.WLM.getOrderedAttributes()); } public void testAdd() { @@ -122,7 +125,8 @@ private static Rule getRule(Set attributeValues, String label) { } private static List> getAttributeExtractors(List extractedAttributes) { - List> extractors = List.of(new AttributeExtractor() { + List> extractors = new ArrayList<>(); + extractors.add(new AttributeExtractor() { @Override public Attribute getAttribute() { return TestAttribute.TEST_ATTRIBUTE; @@ -132,6 +136,11 @@ public Attribute getAttribute() { public Iterable extract() { return extractedAttributes; } + + @Override + public LogicalOperator getLogicalOperator() { + return OR; + } }); return extractors; } diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/CandidateFeatureValuesTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/CandidateFeatureValuesTests.java new file mode 100644 index 0000000000000..1292c56493184 --- /dev/null +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/CandidateFeatureValuesTests.java @@ -0,0 +1,92 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.feature_value_resolver; + +import org.opensearch.rule.attribute_extractor.AttributeExtractor; +import org.opensearch.rule.feature_value_resolver.FeatureValueResolver.FeatureValueResolutionResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +public class CandidateFeatureValuesTests extends OpenSearchTestCase { + + public void testFlattenedValuesAndFirstOccurrence() { + List> input = List.of(Set.of("A", "B"), Set.of("C"), Set.of("A", "D")); + + CandidateFeatureValues cfv = new CandidateFeatureValues(input); + assertEquals(Set.of("A", "B", "C", "D"), cfv.getFlattenedValues()); + assertEquals(0, cfv.getFirstOccurrenceIndex("A")); + assertEquals(1, cfv.getFirstOccurrenceIndex("C")); + assertEquals(2, cfv.getFirstOccurrenceIndex("D")); + assertEquals(Integer.MAX_VALUE, cfv.getFirstOccurrenceIndex("X")); + } + + public void testMergeOr() { + CandidateFeatureValues cfv1 = new CandidateFeatureValues(List.of(Set.of("A"), Set.of("B"))); + CandidateFeatureValues cfv2 = new CandidateFeatureValues(List.of(Set.of("C"), Set.of("D"), Set.of("E"))); + + CandidateFeatureValues merged = cfv1.merge(cfv2, AttributeExtractor.LogicalOperator.OR); + assertEquals( + Set.of("A", "C"), + merged.getFlattenedValues() + .stream() + .filter(v -> v.equals("A") || v.equals("C")) + .collect(HashSet::new, HashSet::add, HashSet::addAll) + ); + assertTrue(merged.getFlattenedValues().containsAll(Set.of("A", "B", "C", "D", "E"))); + } + + public void testMergeAnd() { + CandidateFeatureValues cfv1 = new CandidateFeatureValues(List.of(Set.of("A", "B"), Set.of("C"))); + CandidateFeatureValues cfv2 = new CandidateFeatureValues(List.of(Set.of("B", "C"), Set.of("C", "D"))); + + CandidateFeatureValues merged = cfv1.merge(cfv2, AttributeExtractor.LogicalOperator.AND); + assertTrue(merged.getFlattenedValues().containsAll(Set.of("B", "C"))); + assertFalse(merged.getFlattenedValues().contains("A")); + assertFalse(merged.getFlattenedValues().contains("D")); + } + + public void testResolveTieImmediateWinner() { + CandidateFeatureValues cfv1 = new CandidateFeatureValues(List.of(Set.of("A"), Set.of("B"))); + Set candidates = Set.of("A", "B"); + FeatureValueResolutionResult result = new FeatureValueResolutionResult(List.of(cfv1), candidates); + Optional winner = result.resolveLabel(); + assertTrue(winner.isPresent()); + assertEquals("A", winner.get()); + } + + public void testResolveTieResolvedInSecondIteration() { + CandidateFeatureValues cfv1 = new CandidateFeatureValues(List.of(Set.of("A", "B"))); + CandidateFeatureValues cfv2 = new CandidateFeatureValues(List.of(Set.of("B"), Set.of("A"))); + Set candidates = Set.of("A", "B"); + FeatureValueResolutionResult result = new FeatureValueResolutionResult(List.of(cfv1, cfv2), candidates); + Optional winner = result.resolveLabel(); + assertTrue(winner.isPresent()); + assertEquals("B", winner.get()); + } + + public void testResolveTieNoWinner() { + CandidateFeatureValues cfv1 = new CandidateFeatureValues(List.of(Set.of("A", "B"))); + CandidateFeatureValues cfv2 = new CandidateFeatureValues(List.of(Set.of("B", "A"))); + Set candidates = Set.of("A", "B"); + FeatureValueResolutionResult result = new FeatureValueResolutionResult(List.of(cfv1, cfv2), candidates); + Optional winner = result.resolveLabel(); + assertTrue(winner.isEmpty()); + } + + public void testToStringContainsValues() { + CandidateFeatureValues cfv = new CandidateFeatureValues(List.of(Set.of("A"))); + String str = cfv.toString(); + assertTrue(str.contains("A")); + assertTrue(str.contains("values=")); + } +} diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/FeatureValueCollectorTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/FeatureValueCollectorTests.java new file mode 100644 index 0000000000000..f37b13b6b779c --- /dev/null +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/FeatureValueCollectorTests.java @@ -0,0 +1,90 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.feature_value_resolver; + +import org.opensearch.rule.attribute_extractor.AttributeExtractor; +import org.opensearch.rule.storage.AttributeValueStore; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Set; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@SuppressWarnings("unchecked") +public class FeatureValueCollectorTests extends OpenSearchTestCase { + + private AttributeValueStore attributeValueStore; + private AttributeExtractor attributeExtractor; + + public void setUp() throws Exception { + super.setUp(); + attributeValueStore = mock(AttributeValueStore.class); + attributeExtractor = mock(AttributeExtractor.class); + } + + public void testNoValuesExtractedReturnsNull() { + when(attributeExtractor.extract()).thenReturn(List.of()); + + FeatureValueCollector collector = + new FeatureValueCollector(attributeValueStore, attributeExtractor, "username"); + + CandidateFeatureValues result = collector.collect(); + + assertNull(result); + verifyNoInteractions(attributeValueStore); + } + + public void testSingleMatchingValueReturnsCandidateValues() { + when(attributeExtractor.extract()).thenReturn(List.of("username|alice")); + when(attributeValueStore.getAll("username|alice")).thenReturn(List.of(Set.of("label1", "label2"))); + + FeatureValueCollector collector = + new FeatureValueCollector(attributeValueStore, attributeExtractor, "username"); + + CandidateFeatureValues result = collector.collect(); + + assertNotNull(result); + assertEquals(1, result.getFeatureValuesBySpecificity().size()); + assertTrue(result.getFeatureValuesBySpecificity().get(0).contains("label1")); + assertTrue(result.getFeatureValuesBySpecificity().get(0).contains("label2")); + } + + public void testNonMatchingValuesAreIgnored() { + when(attributeExtractor.extract()).thenReturn(List.of("role|admin")); + + FeatureValueCollector collector = + new FeatureValueCollector(attributeValueStore, attributeExtractor, "username"); + + CandidateFeatureValues result = collector.collect(); + + assertNull(result); + verify(attributeValueStore, never()).get(any()); + } + + public void testMultipleMatchingValuesMerged() { + when(attributeExtractor.extract()).thenReturn(List.of("username|alice", "username|bob")); + when(attributeValueStore.getAll("username|alice")).thenReturn(List.of(Set.of("label1"))); + when(attributeValueStore.getAll("username|bob")).thenReturn(List.of(Set.of("label2"))); + when(attributeExtractor.getLogicalOperator()).thenReturn(AttributeExtractor.LogicalOperator.OR); + + FeatureValueCollector collector = + new FeatureValueCollector(attributeValueStore, attributeExtractor, "username"); + + CandidateFeatureValues result = collector.collect(); + + assertNotNull(result); + assertTrue(result.getFeatureValuesBySpecificity().stream().anyMatch(set -> set.contains("label1") || set.contains("label2"))); + } +} diff --git a/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/FeatureValueResolverTests.java b/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/FeatureValueResolverTests.java new file mode 100644 index 0000000000000..463c9575517b1 --- /dev/null +++ b/modules/autotagging-commons/src/test/java/org/opensearch/rule/feature_value_resolver/FeatureValueResolverTests.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.rule.feature_value_resolver; + +import org.opensearch.rule.attribute_extractor.AttributeExtractor; +import org.opensearch.rule.autotagging.Attribute; +import org.opensearch.rule.storage.AttributeValueStore; +import org.opensearch.rule.storage.AttributeValueStoreFactory; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@SuppressWarnings("unchecked") +public class FeatureValueResolverTests extends OpenSearchTestCase { + + private AttributeValueStoreFactory storeFactory; + private AttributeExtractor extractor; + private AttributeValueStore store; + private Attribute attribute; + + public void setUp() throws Exception { + super.setUp(); + storeFactory = mock(AttributeValueStoreFactory.class); + extractor = mock(AttributeExtractor.class); + store = mock(AttributeValueStore.class); + attribute = mock(Attribute.class); + } + + public void testResolveWithNoExtractorsReturnsEmptyIntersection() { + FeatureValueResolver resolver = new FeatureValueResolver(storeFactory); + FeatureValueResolver.FeatureValueResolutionResult result = resolver.resolve(List.of()); + assertNotNull(result); + assertTrue(result.resolveLabel().isEmpty()); + } + + public void testResolveSingleExtractorSingleSubfieldSingleValue() { + when(extractor.getAttribute()).thenReturn(attribute); + when(attribute.getPrioritizedSubfields()).thenReturn(new TreeMap<>(Map.of(1, "username"))); + when(storeFactory.getAttributeValueStore(attribute)).thenReturn(store); + + when(extractor.extract()).thenReturn(List.of("username|alice")); + when(store.getAll("username|alice")).thenReturn(List.of(Set.of("label1"))); + + FeatureValueResolver resolver = new FeatureValueResolver(storeFactory); + FeatureValueResolver.FeatureValueResolutionResult result = resolver.resolve(List.of(extractor)); + + assertNotNull(result); + Optional resolved = result.resolveLabel(); + assertTrue(resolved.isPresent()); + assertEquals("label1", resolved.get()); + } + + public void testResolveSingleExtractorWithNoSubfieldsDefaultsToEmptyString() { + when(extractor.getAttribute()).thenReturn(attribute); + when(attribute.getPrioritizedSubfields()).thenReturn(new TreeMap<>()); + when(storeFactory.getAttributeValueStore(attribute)).thenReturn(store); + + when(extractor.extract()).thenReturn(List.of("|value")); + when(store.getAll("|value")).thenReturn(List.of(Set.of("labelX"))); + + FeatureValueResolver resolver = new FeatureValueResolver(storeFactory); + FeatureValueResolver.FeatureValueResolutionResult result = resolver.resolve(List.of(extractor)); + + assertNotNull(result); + assertTrue(result.resolveLabel().isPresent()); + assertEquals("labelX", result.resolveLabel().get()); + } + + public void testResolveMultipleExtractorsIntersection() { + AttributeExtractor extractor1 = mock(AttributeExtractor.class); + AttributeExtractor extractor2 = mock(AttributeExtractor.class); + Attribute attr1 = mock(Attribute.class); + Attribute attr2 = mock(Attribute.class); + AttributeValueStore store1 = mock(AttributeValueStore.class); + AttributeValueStore store2 = mock(AttributeValueStore.class); + + when(extractor1.getAttribute()).thenReturn(attr1); + when(extractor2.getAttribute()).thenReturn(attr2); + + when(attr1.getPrioritizedSubfields()).thenReturn(new TreeMap<>(Map.of(1, "username"))); + when(attr2.getPrioritizedSubfields()).thenReturn(new TreeMap<>(Map.of(1, "role"))); + + when(storeFactory.getAttributeValueStore(attr1)).thenReturn(store1); + when(storeFactory.getAttributeValueStore(attr2)).thenReturn(store2); + + when(extractor1.extract()).thenReturn(List.of("username|alice")); + when(extractor2.extract()).thenReturn(List.of("role|admin")); + + when(store1.getAll("username|alice")).thenReturn(List.of(Set.of("label1", "common"))); + when(store2.getAll("role|admin")).thenReturn(List.of(Set.of("common", "label2"))); + + FeatureValueResolver resolver = new FeatureValueResolver(storeFactory); + FeatureValueResolver.FeatureValueResolutionResult result = resolver.resolve(List.of(extractor1, extractor2)); + + assertNotNull(result); + Optional resolved = result.resolveLabel(); + assertTrue(resolved.isPresent()); + assertEquals("common", resolved.get()); + } + + public void testResolveWithNoIntersectionReturnsEmpty() { + AttributeExtractor extractor1 = mock(AttributeExtractor.class); + AttributeExtractor extractor2 = mock(AttributeExtractor.class); + Attribute attr1 = mock(Attribute.class); + Attribute attr2 = mock(Attribute.class); + AttributeValueStore store1 = mock(AttributeValueStore.class); + AttributeValueStore store2 = mock(AttributeValueStore.class); + + when(extractor1.getAttribute()).thenReturn(attr1); + when(extractor2.getAttribute()).thenReturn(attr2); + + when(attr1.getPrioritizedSubfields()).thenReturn(new TreeMap<>(Map.of(1, "username"))); + when(attr2.getPrioritizedSubfields()).thenReturn(new TreeMap<>(Map.of(1, "role"))); + + when(storeFactory.getAttributeValueStore(attr1)).thenReturn(store1); + when(storeFactory.getAttributeValueStore(attr2)).thenReturn(store2); + + when(extractor1.extract()).thenReturn(List.of("username|alice")); + when(extractor2.extract()).thenReturn(List.of("role|admin")); + + when(store1.getAll("username|alice")).thenReturn(List.of(Set.of("label1"))); + when(store2.getAll("role|admin")).thenReturn(List.of(Set.of("label2"))); + + FeatureValueResolver resolver = new FeatureValueResolver(storeFactory); + FeatureValueResolver.FeatureValueResolutionResult result = resolver.resolve(List.of(extractor1, extractor2)); + + assertNotNull(result); + assertTrue(result.resolveLabel().isEmpty()); + } +} diff --git a/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java b/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java index 5bc5b6658f53d..797ed092df806 100644 --- a/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java +++ b/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java @@ -73,8 +73,10 @@ import org.opensearch.plugin.wlm.rule.sync.RefreshBasedSyncMechanism; import org.opensearch.plugin.wlm.rule.sync.detect.RuleEventClassifier; import org.opensearch.plugin.wlm.service.WorkloadGroupPersistenceService; +import org.opensearch.plugin.wlm.spi.AttributeExtractorExtension; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.DiscoveryPlugin; +import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SystemIndexPlugin; import org.opensearch.repositories.RepositoriesService; @@ -133,6 +135,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.plugin.wlm.WorkloadManagementPlugin.PRINCIPAL_ATTRIBUTE_NAME; import static org.opensearch.search.SearchService.CLUSTER_CONCURRENT_SEGMENT_SEARCH_SETTING; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; import static org.opensearch.threadpool.ThreadPool.Names.SAME; @@ -717,6 +720,7 @@ public static class TestWorkloadManagementPlugin extends Plugin ActionPlugin, SystemIndexPlugin, DiscoveryPlugin, + ExtensiblePlugin, RuleFrameworkExtension { /** @@ -729,8 +733,10 @@ public static class TestWorkloadManagementPlugin extends Plugin public static final int MAX_RULES_PER_PAGE = 50; static FeatureType featureType; static RulePersistenceService rulePersistenceService; + private static final Map orderedAttributes = new HashMap<>(); static RuleRoutingService ruleRoutingService; private AutoTaggingActionFilter autoTaggingActionFilter; + private final Map attributeExtractorExtensions = new HashMap<>(); /** * Default constructor. @@ -758,7 +764,10 @@ public Collection createComponents( featureType, DefaultAttributeValueStore::new ); - InMemoryRuleProcessingService ruleProcessingService = new InMemoryRuleProcessingService(attributeValueStoreFactory); + InMemoryRuleProcessingService ruleProcessingService = new InMemoryRuleProcessingService( + attributeValueStoreFactory, + featureType.getOrderedAttributes() + ); rulePersistenceService = new IndexStoredRulePersistenceService( INDEX_NAME, client, @@ -782,7 +791,13 @@ public Collection createComponents( wlmClusterSettingValuesProvider ); - autoTaggingActionFilter = new AutoTaggingActionFilter(ruleProcessingService, threadPool); + autoTaggingActionFilter = new AutoTaggingActionFilter( + ruleProcessingService, + threadPool, + attributeExtractorExtensions, + wlmClusterSettingValuesProvider, + featureType + ); return List.of(refreshMechanism); } @@ -865,6 +880,18 @@ public Supplier getFeatureTypeSupplier() { } @Override - public void setAttributes(List attributes) {} + public void setAttributes(List attributes) { + for (Attribute attribute : attributes) { + if (attribute.getName().equals(PRINCIPAL_ATTRIBUTE_NAME)) { + orderedAttributes.put(attribute, 1); + } + } + } + + public void loadExtensions(ExtensiblePlugin.ExtensionLoader loader) { + for (AttributeExtractorExtension ext : loader.loadExtensions(AttributeExtractorExtension.class)) { + attributeExtractorExtensions.put(ext.getAttributeExtractor().getAttribute(), ext); + } + } } } diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java index 1268f0f69b5eb..112e2a11956f0 100644 --- a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java @@ -16,29 +16,53 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.plugin.wlm.rule.attribute_extractor.IndicesExtractor; +import org.opensearch.plugin.wlm.spi.AttributeExtractorExtension; import org.opensearch.rule.InMemoryRuleProcessingService; +import org.opensearch.rule.attribute_extractor.AttributeExtractor; +import org.opensearch.rule.autotagging.Attribute; +import org.opensearch.rule.autotagging.FeatureType; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.wlm.WlmMode; import org.opensearch.wlm.WorkloadGroupTask; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; +import static org.opensearch.plugin.wlm.WorkloadManagementPlugin.PRINCIPAL_ATTRIBUTE_NAME; + /** * This class is responsible to evaluate and assign the WORKLOAD_GROUP_ID header in ThreadContext */ public class AutoTaggingActionFilter implements ActionFilter { private final InMemoryRuleProcessingService ruleProcessingService; - ThreadPool threadPool; + private final ThreadPool threadPool; + private final Map attributeExtensions; + private final WlmClusterSettingValuesProvider wlmClusterSettingValuesProvider; + private final FeatureType featureType; /** * Main constructor * @param ruleProcessingService provides access to in memory view of rules * @param threadPool to access assign the label + * @param attributeExtensions + * @param wlmClusterSettingValuesProvider + * @param featureType */ - public AutoTaggingActionFilter(InMemoryRuleProcessingService ruleProcessingService, ThreadPool threadPool) { + public AutoTaggingActionFilter( + InMemoryRuleProcessingService ruleProcessingService, + ThreadPool threadPool, + Map attributeExtensions, + WlmClusterSettingValuesProvider wlmClusterSettingValuesProvider, + FeatureType featureType + ) { this.ruleProcessingService = ruleProcessingService; this.threadPool = threadPool; + this.attributeExtensions = attributeExtensions; + this.wlmClusterSettingValuesProvider = wlmClusterSettingValuesProvider; + this.featureType = featureType; } @Override @@ -56,12 +80,20 @@ public void app ) { final boolean isValidRequest = request instanceof SearchRequest; - if (!isValidRequest) { + if (!isValidRequest || wlmClusterSettingValuesProvider.getWlmMode() == WlmMode.DISABLED) { chain.proceed(task, action, request, listener); return; } - Optional label = ruleProcessingService.evaluateLabel(List.of(new IndicesExtractor((IndicesRequest) request))); + List> attributeExtractors = new ArrayList<>(); + attributeExtractors.add(new IndicesExtractor((IndicesRequest) request)); + + if (featureType.getAllowedAttributesRegistry().containsKey(PRINCIPAL_ATTRIBUTE_NAME)) { + Attribute attribute = featureType.getAllowedAttributesRegistry().get(PRINCIPAL_ATTRIBUTE_NAME); + assert attributeExtensions.containsKey(attribute); + attributeExtractors.add(attributeExtensions.get(attribute).getAttributeExtractor()); + } + Optional label = ruleProcessingService.evaluateLabel(attributeExtractors); label.ifPresent(s -> threadPool.getThreadContext().putHeader(WorkloadGroupTask.WORKLOAD_GROUP_ID_HEADER, s)); chain.proceed(task, action, request, listener); } diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java index 09a26444c09e1..d5bfede75926c 100644 --- a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/WorkloadManagementPlugin.java @@ -141,7 +141,10 @@ public Collection createComponents( featureType, DefaultAttributeValueStore::new ); - InMemoryRuleProcessingService ruleProcessingService = new InMemoryRuleProcessingService(attributeValueStoreFactory); + InMemoryRuleProcessingService ruleProcessingService = new InMemoryRuleProcessingService( + attributeValueStoreFactory, + featureType.getOrderedAttributes() + ); rulePersistenceService = new IndexStoredRulePersistenceService( INDEX_NAME, client, @@ -161,7 +164,13 @@ public Collection createComponents( wlmClusterSettingValuesProvider ); - autoTaggingActionFilter = new AutoTaggingActionFilter(ruleProcessingService, threadPool); + autoTaggingActionFilter = new AutoTaggingActionFilter( + ruleProcessingService, + threadPool, + attributeExtractorExtensions, + wlmClusterSettingValuesProvider, + featureType + ); return List.of(refreshMechanism); } diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/attribute_extractor/IndicesExtractor.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/attribute_extractor/IndicesExtractor.java index e556e2984777e..05057dbf41634 100644 --- a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/attribute_extractor/IndicesExtractor.java +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/rule/attribute_extractor/IndicesExtractor.java @@ -15,6 +15,8 @@ import java.util.List; +import static org.opensearch.rule.attribute_extractor.AttributeExtractor.LogicalOperator.AND; + /** * This class extracts the indices from a request */ @@ -38,4 +40,9 @@ public Attribute getAttribute() { public Iterable extract() { return List.of(indicesRequest.indices()); } + + @Override + public LogicalOperator getLogicalOperator() { + return AND; + } } diff --git a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java index e5a0de1614852..d445d24d45f97 100644 --- a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java +++ b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java @@ -28,8 +28,10 @@ import org.opensearch.wlm.WorkloadGroupTask; import java.io.IOException; +import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.TreeMap; import static org.mockito.Mockito.anyList; import static org.mockito.Mockito.mock; @@ -51,8 +53,14 @@ public void setUp() throws Exception { WLMFeatureType.WLM, DefaultAttributeValueStore::new ); - ruleProcessingService = spy(new InMemoryRuleProcessingService(attributeValueStoreFactory)); - autoTaggingActionFilter = new AutoTaggingActionFilter(ruleProcessingService, threadPool); + ruleProcessingService = spy(new InMemoryRuleProcessingService(attributeValueStoreFactory, null)); + autoTaggingActionFilter = new AutoTaggingActionFilter( + ruleProcessingService, + threadPool, + new HashMap<>(), + mock(WlmClusterSettingValuesProvider.class), + WLMFeatureType.WLM + ); } public void tearDown() throws Exception { @@ -114,6 +122,11 @@ public String getName() { return name; } + @Override + public TreeMap getPrioritizedSubfields() { + return new TreeMap<>(); + } + @Override public void validateAttribute() {}