7
7
8
8
import static java .util .concurrent .TimeUnit .SECONDS ;
9
9
import static org .opensearch .core .xcontent .XContentParserUtils .ensureExpectedToken ;
10
- import static org .opensearch .ml .common .CommonValue .stopWordsIndices ;
11
10
import static org .opensearch .ml .common .utils .StringUtils .gson ;
12
11
13
12
import java .io .IOException ;
25
24
import java .util .stream .Collectors ;
26
25
27
26
import org .opensearch .action .LatchedActionListener ;
28
- import org .opensearch .action .search .SearchRequest ;
29
27
import org .opensearch .action .search .SearchResponse ;
30
28
import org .opensearch .common .util .concurrent .ThreadContext ;
31
29
import org .opensearch .common .xcontent .LoggingDeprecationHandler ;
36
34
import org .opensearch .core .xcontent .NamedXContentRegistry ;
37
35
import org .opensearch .core .xcontent .XContentBuilder ;
38
36
import org .opensearch .core .xcontent .XContentParser ;
37
+ import org .opensearch .remote .metadata .client .SdkClient ;
38
+ import org .opensearch .remote .metadata .client .SearchDataObjectRequest ;
39
+ import org .opensearch .remote .metadata .common .SdkClientUtils ;
39
40
import org .opensearch .search .builder .SearchSourceBuilder ;
40
41
import org .opensearch .transport .client .Client ;
41
42
@@ -58,6 +59,8 @@ public class LocalRegexGuardrail extends Guardrail {
58
59
private Map <String , List <String >> stopWordsIndicesInput ;
59
60
private NamedXContentRegistry xContentRegistry ;
60
61
private Client client ;
62
+ private SdkClient sdkClient ;
63
+ private String tenantId ;
61
64
62
65
@ Builder (toBuilder = true )
63
66
public LocalRegexGuardrail (List <StopWords > stopWords , String [] regex ) {
@@ -109,9 +112,11 @@ public Boolean validate(String input, Map<String, String> parameters) {
109
112
}
110
113
111
114
@ Override
112
- public void init (NamedXContentRegistry xContentRegistry , Client client ) {
115
+ public void init (NamedXContentRegistry xContentRegistry , Client client , SdkClient sdkClient , String tenantId ) {
113
116
this .xContentRegistry = xContentRegistry ;
114
117
this .client = client ;
118
+ this .sdkClient = sdkClient ;
119
+ this .tenantId = tenantId ;
115
120
init ();
116
121
}
117
122
@@ -211,55 +216,34 @@ public Boolean validateStopWords(String input, Map<String, List<String>> stopWor
211
216
* @return true if no stop words matching, otherwise false.
212
217
*/
213
218
public Boolean validateStopWordsSingleIndex (String input , String indexName , List <String > fieldNames ) {
214
- SearchRequest searchRequest ;
215
- AtomicBoolean hitStopWords = new AtomicBoolean (false );
219
+ AtomicBoolean passedStopWordCheck = new AtomicBoolean (false );
216
220
String queryBody ;
217
221
Map <String , String > documentMap = new HashMap <>();
218
222
for (String field : fieldNames ) {
219
223
documentMap .put (field , input );
220
224
}
221
225
Map <String , Object > queryBodyMap = Map .of ("query" , Map .of ("percolate" , Map .of ("field" , "query" , "document" , documentMap )));
222
226
CountDownLatch latch = new CountDownLatch (1 );
223
- ThreadContext .StoredContext context = null ;
224
-
225
227
try {
226
228
queryBody = AccessController .doPrivileged ((PrivilegedExceptionAction <String >) () -> gson .toJson (queryBodyMap ));
227
- SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
228
- XContentParser queryParser = XContentType .JSON
229
- .xContent ()
230
- .createParser (xContentRegistry , LoggingDeprecationHandler .INSTANCE , queryBody );
231
- searchSourceBuilder .parseXContent (queryParser );
232
- searchSourceBuilder .size (1 ); // Only need 1 doc returned, if hit.
233
- searchRequest = new SearchRequest ().source (searchSourceBuilder ).indices (indexName );
234
- if (isStopWordsSystemIndex (indexName )) {
235
- context = client .threadPool ().getThreadContext ().stashContext ();
236
- ThreadContext .StoredContext finalContext = context ;
237
- client .search (searchRequest , ActionListener .runBefore (new LatchedActionListener (ActionListener .<SearchResponse >wrap (r -> {
238
- if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value () == 0 ) {
239
- hitStopWords .set (true );
240
- }
241
- }, e -> {
242
- log .error ("Failed to search stop words index {}" , indexName , e );
243
- hitStopWords .set (true );
244
- }), latch ), () -> finalContext .restore ()));
245
- } else {
246
- client .search (searchRequest , new LatchedActionListener (ActionListener .<SearchResponse >wrap (r -> {
247
- if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value () == 0 ) {
248
- hitStopWords .set (true );
249
- }
250
- }, e -> {
251
- log .error ("Failed to search stop words index {}" , indexName , e );
252
- hitStopWords .set (true );
253
- }), latch ));
229
+ SearchDataObjectRequest searchDataObjectRequest = buildSearchDataObjectRequest (indexName , queryBody );
230
+ var responseListener = new LatchedActionListener <>(ActionListener .<SearchResponse >wrap (r -> {
231
+ if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value () == 0 ) {
232
+ passedStopWordCheck .set (true );
233
+ }
234
+ }, e -> {
235
+ log .error ("Failed to search stop words index {}" , indexName , e );
236
+ passedStopWordCheck .set (true );
237
+ }), latch );
238
+ try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
239
+ sdkClient
240
+ .searchDataObjectAsync (searchDataObjectRequest )
241
+ .whenComplete (SdkClientUtils .wrapSearchCompletion (ActionListener .runBefore (responseListener , context ::restore )));
254
242
}
255
243
} catch (Exception e ) {
256
244
log .error ("[validateStopWords] Searching stop words index failed." , e );
257
245
latch .countDown ();
258
- hitStopWords .set (true );
259
- } finally {
260
- if (context != null ) {
261
- context .close ();
262
- }
246
+ passedStopWordCheck .set (true );
263
247
}
264
248
265
249
try {
@@ -268,10 +252,17 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
268
252
log .error ("[validateStopWords] Searching stop words index was timeout." , e );
269
253
throw new IllegalStateException (e );
270
254
}
271
- return hitStopWords .get ();
255
+ return passedStopWordCheck .get ();
272
256
}
273
257
274
- private boolean isStopWordsSystemIndex (String index ) {
275
- return stopWordsIndices .contains (index );
258
+ protected SearchDataObjectRequest buildSearchDataObjectRequest (String indexName , String queryBody ) throws IOException {
259
+ SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
260
+ XContentParser queryParser = XContentType .JSON
261
+ .xContent ()
262
+ .createParser (xContentRegistry , LoggingDeprecationHandler .INSTANCE , queryBody );
263
+ searchSourceBuilder .parseXContent (queryParser );
264
+ searchSourceBuilder .size (1 ); // Only need 1 doc returned, if hit.
265
+
266
+ return SearchDataObjectRequest .builder ().indices (indexName ).searchSourceBuilder (searchSourceBuilder ).tenantId (tenantId ).build ();
276
267
}
277
268
}
0 commit comments