Skip to content

Commit 70d8f18

Browse files
authored
Support multi-tenancy for LocalRegexGuardrail (opensearch-project#4120)
* Introduce sdk client to LocalRegexGuardrail Signed-off-by: Yuanchun Shen <[email protected]> * Use try-with-resource when validating stop words Signed-off-by: Yuanchun Shen <[email protected]> * Improve unit tests for LocalRegexGuardrail Signed-off-by: Yuanchun Shen <[email protected]> * Unit test failed cases for validateStopWordsSingleIndex Signed-off-by: Yuanchun Shen <[email protected]> --------- Signed-off-by: Yuanchun Shen <[email protected]>
1 parent 667be7e commit 70d8f18

File tree

9 files changed

+184
-146
lines changed

9 files changed

+184
-146
lines changed

common/src/main/java/org/opensearch/ml/common/model/Guardrail.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.opensearch.core.common.io.stream.StreamOutput;
1212
import org.opensearch.core.xcontent.NamedXContentRegistry;
1313
import org.opensearch.core.xcontent.ToXContentObject;
14+
import org.opensearch.remote.metadata.client.SdkClient;
1415
import org.opensearch.transport.client.Client;
1516

1617
public abstract class Guardrail implements ToXContentObject {
@@ -19,5 +20,5 @@ public abstract class Guardrail implements ToXContentObject {
1920

2021
public abstract Boolean validate(String input, Map<String, String> parameters);
2122

22-
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
23+
public abstract void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId);
2324
}

common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import static java.util.concurrent.TimeUnit.SECONDS;
99
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
10-
import static org.opensearch.ml.common.CommonValue.stopWordsIndices;
1110
import static org.opensearch.ml.common.utils.StringUtils.gson;
1211

1312
import java.io.IOException;
@@ -25,7 +24,6 @@
2524
import java.util.stream.Collectors;
2625

2726
import org.opensearch.action.LatchedActionListener;
28-
import org.opensearch.action.search.SearchRequest;
2927
import org.opensearch.action.search.SearchResponse;
3028
import org.opensearch.common.util.concurrent.ThreadContext;
3129
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
@@ -36,6 +34,9 @@
3634
import org.opensearch.core.xcontent.NamedXContentRegistry;
3735
import org.opensearch.core.xcontent.XContentBuilder;
3836
import org.opensearch.core.xcontent.XContentParser;
37+
import org.opensearch.remote.metadata.client.SdkClient;
38+
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
39+
import org.opensearch.remote.metadata.common.SdkClientUtils;
3940
import org.opensearch.search.builder.SearchSourceBuilder;
4041
import org.opensearch.transport.client.Client;
4142

@@ -58,6 +59,8 @@ public class LocalRegexGuardrail extends Guardrail {
5859
private Map<String, List<String>> stopWordsIndicesInput;
5960
private NamedXContentRegistry xContentRegistry;
6061
private Client client;
62+
private SdkClient sdkClient;
63+
private String tenantId;
6164

6265
@Builder(toBuilder = true)
6366
public LocalRegexGuardrail(List<StopWords> stopWords, String[] regex) {
@@ -109,9 +112,11 @@ public Boolean validate(String input, Map<String, String> parameters) {
109112
}
110113

111114
@Override
112-
public void init(NamedXContentRegistry xContentRegistry, Client client) {
115+
public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
113116
this.xContentRegistry = xContentRegistry;
114117
this.client = client;
118+
this.sdkClient = sdkClient;
119+
this.tenantId = tenantId;
115120
init();
116121
}
117122

@@ -211,55 +216,34 @@ public Boolean validateStopWords(String input, Map<String, List<String>> stopWor
211216
* @return true if no stop words matching, otherwise false.
212217
*/
213218
public Boolean validateStopWordsSingleIndex(String input, String indexName, List<String> fieldNames) {
214-
SearchRequest searchRequest;
215-
AtomicBoolean hitStopWords = new AtomicBoolean(false);
219+
AtomicBoolean passedStopWordCheck = new AtomicBoolean(false);
216220
String queryBody;
217221
Map<String, String> documentMap = new HashMap<>();
218222
for (String field : fieldNames) {
219223
documentMap.put(field, input);
220224
}
221225
Map<String, Object> queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
222226
CountDownLatch latch = new CountDownLatch(1);
223-
ThreadContext.StoredContext context = null;
224-
225227
try {
226228
queryBody = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBodyMap));
227-
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
228-
XContentParser queryParser = XContentType.JSON
229-
.xContent()
230-
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
231-
searchSourceBuilder.parseXContent(queryParser);
232-
searchSourceBuilder.size(1); // Only need 1 doc returned, if hit.
233-
searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
234-
if (isStopWordsSystemIndex(indexName)) {
235-
context = client.threadPool().getThreadContext().stashContext();
236-
ThreadContext.StoredContext finalContext = context;
237-
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
238-
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
239-
hitStopWords.set(true);
240-
}
241-
}, e -> {
242-
log.error("Failed to search stop words index {}", indexName, e);
243-
hitStopWords.set(true);
244-
}), latch), () -> finalContext.restore()));
245-
} else {
246-
client.search(searchRequest, new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
247-
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
248-
hitStopWords.set(true);
249-
}
250-
}, e -> {
251-
log.error("Failed to search stop words index {}", indexName, e);
252-
hitStopWords.set(true);
253-
}), latch));
229+
SearchDataObjectRequest searchDataObjectRequest = buildSearchDataObjectRequest(indexName, queryBody);
230+
var responseListener = new LatchedActionListener<>(ActionListener.<SearchResponse>wrap(r -> {
231+
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
232+
passedStopWordCheck.set(true);
233+
}
234+
}, e -> {
235+
log.error("Failed to search stop words index {}", indexName, e);
236+
passedStopWordCheck.set(true);
237+
}), latch);
238+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
239+
sdkClient
240+
.searchDataObjectAsync(searchDataObjectRequest)
241+
.whenComplete(SdkClientUtils.wrapSearchCompletion(ActionListener.runBefore(responseListener, context::restore)));
254242
}
255243
} catch (Exception e) {
256244
log.error("[validateStopWords] Searching stop words index failed.", e);
257245
latch.countDown();
258-
hitStopWords.set(true);
259-
} finally {
260-
if (context != null) {
261-
context.close();
262-
}
246+
passedStopWordCheck.set(true);
263247
}
264248

265249
try {
@@ -268,10 +252,17 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
268252
log.error("[validateStopWords] Searching stop words index was timeout.", e);
269253
throw new IllegalStateException(e);
270254
}
271-
return hitStopWords.get();
255+
return passedStopWordCheck.get();
272256
}
273257

274-
private boolean isStopWordsSystemIndex(String index) {
275-
return stopWordsIndices.contains(index);
258+
protected SearchDataObjectRequest buildSearchDataObjectRequest(String indexName, String queryBody) throws IOException {
259+
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
260+
XContentParser queryParser = XContentType.JSON
261+
.xContent()
262+
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
263+
searchSourceBuilder.parseXContent(queryParser);
264+
searchSourceBuilder.size(1); // Only need 1 doc returned, if hit.
265+
266+
return SearchDataObjectRequest.builder().indices(indexName).searchSourceBuilder(searchSourceBuilder).tenantId(tenantId).build();
276267
}
277268
}

common/src/main/java/org/opensearch/ml/common/model/MLGuard.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99

1010
import org.opensearch.core.xcontent.NamedXContentRegistry;
11+
import org.opensearch.remote.metadata.client.SdkClient;
1112
import org.opensearch.transport.client.Client;
1213

1314
import lombok.Getter;
@@ -18,17 +19,21 @@
1819
public class MLGuard {
1920
private NamedXContentRegistry xContentRegistry;
2021
private Client client;
22+
private final SdkClient sdkClient;
23+
private final String tenantId;
2124
private Guardrails guardrails;
2225

23-
public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
26+
public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
2427
this.xContentRegistry = xContentRegistry;
2528
this.client = client;
29+
this.sdkClient = sdkClient;
30+
this.tenantId = tenantId;
2631
this.guardrails = guardrails;
2732
if (this.guardrails != null && this.guardrails.getInputGuardrail() != null) {
28-
this.guardrails.getInputGuardrail().init(xContentRegistry, client);
33+
this.guardrails.getInputGuardrail().init(xContentRegistry, client, sdkClient, tenantId);
2934
}
3035
if (this.guardrails != null && this.guardrails.getOutputGuardrail() != null) {
31-
this.guardrails.getOutputGuardrail().init(xContentRegistry, client);
36+
this.guardrails.getOutputGuardrail().init(xContentRegistry, client, sdkClient, tenantId);
3237
}
3338
}
3439

common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.opensearch.ml.common.transport.MLTaskResponse;
3838
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
3939
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
40+
import org.opensearch.remote.metadata.client.SdkClient;
4041
import org.opensearch.transport.client.Client;
4142

4243
import lombok.Builder;
@@ -58,6 +59,8 @@ public class ModelGuardrail extends Guardrail {
5859
private String responseAccept;
5960
private NamedXContentRegistry xContentRegistry;
6061
private Client client;
62+
private SdkClient sdkClient;
63+
private String tenantId;
6164
private Pattern regexAcceptPattern;
6265

6366
@Builder(toBuilder = true)
@@ -141,9 +144,11 @@ public Boolean validate(String in, Map<String, String> parameters) {
141144
}
142145

143146
@Override
144-
public void init(NamedXContentRegistry xContentRegistry, Client client) {
147+
public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
145148
this.xContentRegistry = xContentRegistry;
146149
this.client = client;
150+
this.sdkClient = sdkClient;
151+
this.tenantId = tenantId;
147152
regexAcceptPattern = Pattern.compile(responseAccept);
148153
}
149154

0 commit comments

Comments
 (0)