Skip to content

Commit fa3ed30

Browse files
authored
add cmk role/assume role support in index insight (#4462)
* add cmk role/assume role support in index insight Signed-off-by: xinyual <[email protected]> * move location Signed-off-by: xinyual <[email protected]> * fix NPE problem Signed-off-by: xinyual <[email protected]> * add ut Signed-off-by: xinyual <[email protected]> * apply spotless Signed-off-by: xinyual <[email protected]> --------- Signed-off-by: xinyual <[email protected]>
1 parent b1f7291 commit fa3ed30

File tree

17 files changed

+199
-45
lines changed

17 files changed

+199
-45
lines changed

common/src/main/java/org/opensearch/ml/common/indexInsight/AbstractIndexInsightTask.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,23 @@ public abstract class AbstractIndexInsightTask implements IndexInsightTask {
7171
protected final String sourceIndex;
7272
protected final Client client;
7373
protected final SdkClient sdkClient;
74+
protected final String cmkRoleArn;
75+
protected final String cmkAssumeRoleArn;
7476

75-
protected AbstractIndexInsightTask(MLIndexInsightType taskType, String sourceIndex, Client client, SdkClient sdkClient) {
77+
protected AbstractIndexInsightTask(
78+
MLIndexInsightType taskType,
79+
String sourceIndex,
80+
Client client,
81+
SdkClient sdkClient,
82+
String cmkRoleArn,
83+
String cmkAssumeRoleArn
84+
) {
7685
this.taskType = taskType;
7786
this.sourceIndex = sourceIndex;
7887
this.client = client;
7988
this.sdkClient = sdkClient;
89+
this.cmkRoleArn = cmkRoleArn;
90+
this.cmkAssumeRoleArn = cmkAssumeRoleArn;
8091
}
8192

8293
/**
@@ -327,7 +338,14 @@ private void getIndexInsight(String docId, String tenantId, ActionListener<GetRe
327338
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
328339
sdkClient
329340
.getDataObjectAsync(
330-
GetDataObjectRequest.builder().tenantId(tenantId).index(ML_INDEX_INSIGHT_STORAGE_INDEX).id(docId).build()
341+
GetDataObjectRequest
342+
.builder()
343+
.tenantId(tenantId)
344+
.index(ML_INDEX_INSIGHT_STORAGE_INDEX)
345+
.id(docId)
346+
.cmkRoleArn(cmkRoleArn)
347+
.assumeRoleArn(cmkAssumeRoleArn)
348+
.build()
331349
)
332350
.whenComplete((r, throwable) -> {
333351
context.restore();
@@ -361,6 +379,8 @@ private void writeIndexInsight(IndexInsight indexInsight, String tenantId, Actio
361379
.index(ML_INDEX_INSIGHT_STORAGE_INDEX)
362380
.dataObject(indexInsight)
363381
.id(docId)
382+
.cmkRoleArn(cmkRoleArn)
383+
.assumeRoleArn(cmkAssumeRoleArn)
364384
.build()
365385
)
366386
.whenComplete((r, throwable) -> {

common/src/main/java/org/opensearch/ml/common/indexInsight/FieldDescriptionTask.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ public class FieldDescriptionTask extends AbstractIndexInsightTask {
4040

4141
private static final int BATCH_SIZE = 50; // Hard-coded value for now
4242

43-
public FieldDescriptionTask(String sourceIndex, Client client, SdkClient sdkClient) {
44-
super(MLIndexInsightType.FIELD_DESCRIPTION, sourceIndex, client, sdkClient);
43+
public FieldDescriptionTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn) {
44+
super(MLIndexInsightType.FIELD_DESCRIPTION, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn);
4545
}
4646

4747
@Override
@@ -330,7 +330,7 @@ private Map<String, Object> parseFieldDescription(String modelResponse) {
330330
@Override
331331
public IndexInsightTask createPrerequisiteTask(MLIndexInsightType prerequisiteType) {
332332
if (prerequisiteType == MLIndexInsightType.STATISTICAL_DATA) {
333-
return new StatisticalDataTask(sourceIndex, client, sdkClient);
333+
return new StatisticalDataTask(sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn);
334334
}
335335
throw new IllegalStateException("Unsupported prerequisite type: " + prerequisiteType);
336336
}

common/src/main/java/org/opensearch/ml/common/indexInsight/LogRelatedIndexCheckTask.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ Your task is to analyze the structure and semantics of this index, and determine
8181
- Your judgment should be based on both semantics and field patterns (e.g., field names like "message", "log", "trace", "span", etc).
8282
""";
8383

84-
public LogRelatedIndexCheckTask(String sourceIndex, Client client, SdkClient sdkClient) {
85-
super(MLIndexInsightType.LOG_RELATED_INDEX_CHECK, sourceIndex, client, sdkClient);
84+
public LogRelatedIndexCheckTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn) {
85+
super(MLIndexInsightType.LOG_RELATED_INDEX_CHECK, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn);
8686
}
8787

8888
@Override

common/src/main/java/org/opensearch/ml/common/indexInsight/StatisticalDataTask.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ public class StatisticalDataTask extends AbstractIndexInsightTask {
8888
detailed information: %s
8989
""";
9090

91-
public StatisticalDataTask(String sourceIndex, Client client, SdkClient sdkClient) {
92-
super(MLIndexInsightType.STATISTICAL_DATA, sourceIndex, client, sdkClient);
91+
public StatisticalDataTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn) {
92+
super(MLIndexInsightType.STATISTICAL_DATA, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn);
9393
}
9494

9595
@Override

common/src/main/java/org/opensearch/ml/common/input/Constants.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,6 @@ public class Constants {
3737
public static final String AD_ANOMALY_SCORE_THRESHOLD = "anomalyScoreThreshold";
3838
public static final String AD_DATE_FORMAT = "dateFormat";
3939
public static final String TENANT_ID_HEADER = "x-tenant-id";
40+
public static final String CMK_ROLE_FIELD = "x-cmk-role";
41+
public static final String CMK_ASSUME_ROLE_FIELD = "x-cmk-assume-role";
4042
}

common/src/main/java/org/opensearch/ml/common/transport/indexInsight/MLIndexInsightGetRequest.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,30 @@ public class MLIndexInsightGetRequest extends ActionRequest {
2929
String indexName;
3030
MLIndexInsightType targetIndexInsight;
3131
String tenantId;
32-
33-
public MLIndexInsightGetRequest(String indexName, MLIndexInsightType targetIndexInsight, String tenantId) {
32+
String cmkRoleArn;
33+
String assumeRoleArn;
34+
35+
public MLIndexInsightGetRequest(
36+
String indexName,
37+
MLIndexInsightType targetIndexInsight,
38+
String tenantId,
39+
String cmkRoleArn,
40+
String assumeRoleArn
41+
) {
3442
this.indexName = indexName;
3543
this.targetIndexInsight = targetIndexInsight;
3644
this.tenantId = tenantId;
45+
this.cmkRoleArn = cmkRoleArn;
46+
this.assumeRoleArn = assumeRoleArn;
3747
}
3848

3949
public MLIndexInsightGetRequest(StreamInput in) throws IOException {
4050
super(in);
4151
this.indexName = in.readString();
4252
this.targetIndexInsight = MLIndexInsightType.fromString(in.readString());
4353
this.tenantId = in.readOptionalString();
54+
this.cmkRoleArn = in.readOptionalString();
55+
this.assumeRoleArn = in.readOptionalString();
4456
}
4557

4658
@Override
@@ -49,6 +61,8 @@ public void writeTo(StreamOutput out) throws IOException {
4961
out.writeString(this.indexName);
5062
out.writeString(this.targetIndexInsight.name());
5163
out.writeOptionalString(tenantId);
64+
out.writeOptionalString(cmkRoleArn);
65+
out.writeOptionalString(assumeRoleArn);
5266
}
5367

5468
@Override

common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
import java.util.Map;
1414

1515
import org.apache.commons.text.StringSubstitutor;
16+
import org.opensearch.OpenSearchStatusException;
1617
import org.opensearch.common.xcontent.json.JsonXContent;
18+
import org.opensearch.core.rest.RestStatus;
1719
import org.opensearch.ml.common.agent.MLToolSpec;
1820
import org.opensearch.ml.common.output.model.ModelTensor;
1921
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2022
import org.opensearch.ml.common.output.model.ModelTensors;
23+
import org.opensearch.rest.RestRequest;
2124

2225
import com.google.gson.reflect.TypeToken;
2326
import com.jayway.jsonpath.JsonPath;
@@ -249,4 +252,24 @@ public static ModelTensor convertOutputToModelTensor(Object output, String outpu
249252
}
250253
return modelTensor;
251254
}
255+
256+
/**
257+
* Fetch the attribute value from rest request's header.
258+
* @param targetKey The key we want to fetch
259+
* @param restRequest The rest request
260+
* @return The value in the rest request
261+
*/
262+
public static String getAttributeFromHeader(String targetKey, RestRequest restRequest) {
263+
Map<String, List<String>> headers = restRequest.getHeaders();
264+
if (headers == null) {
265+
throw new OpenSearchStatusException("Rest request headers can't be null", RestStatus.FORBIDDEN);
266+
}
267+
268+
List<String> resultList = headers.get(targetKey);
269+
if (resultList == null || resultList.isEmpty()) {
270+
return null;
271+
}
272+
return resultList.getFirst();
273+
}
274+
252275
}

common/src/test/java/org/opensearch/ml/common/indexInsight/AbstractIndexInsightTaskTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ public void testCallLLMWithAgent_Failure() {
509509
private static class TestIndexInsightTask extends AbstractIndexInsightTask {
510510

511511
TestIndexInsightTask(Client client, SdkClient sdkClient) {
512-
super(MLIndexInsightType.FIELD_DESCRIPTION, "test-index", client, sdkClient);
512+
super(MLIndexInsightType.FIELD_DESCRIPTION, "test-index", client, sdkClient, null, null);
513513
}
514514

515515
@Override
@@ -532,7 +532,7 @@ public IndexInsightTask createPrerequisiteTask(MLIndexInsightType prerequisiteTy
532532
private static class SimpleTestTask extends AbstractIndexInsightTask {
533533

534534
SimpleTestTask(Client client, SdkClient sdkClient) {
535-
super(MLIndexInsightType.STATISTICAL_DATA, "test-index", client, sdkClient);
535+
super(MLIndexInsightType.STATISTICAL_DATA, "test-index", client, sdkClient, null, null);
536536
}
537537

538538
@Override

common/src/test/java/org/opensearch/ml/common/indexInsight/FieldDescriptionTaskTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public class FieldDescriptionTaskTests {
5555
public void setUp() {
5656
client = mock(Client.class);
5757
sdkClient = mock(SdkClient.class);
58-
task = new FieldDescriptionTask("test-index", client, sdkClient);
58+
task = new FieldDescriptionTask("test-index", client, sdkClient, null, null);
5959
listener = mock(ActionListener.class);
6060
threadPool = mock(ThreadPool.class);
6161
Settings settings = Settings.builder().build();

common/src/test/java/org/opensearch/ml/common/indexInsight/LogRelatedIndexCheckTaskTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public class LogRelatedIndexCheckTaskTests {
4646
public void setUp() {
4747
client = mock(Client.class);
4848
sdkClient = mock(SdkClient.class);
49-
task = new LogRelatedIndexCheckTask("test-index", client, sdkClient);
49+
task = new LogRelatedIndexCheckTask("test-index", client, sdkClient, null, null);
5050
threadPool = mock(ThreadPool.class);
5151
Settings settings = Settings.builder().build();
5252
threadContext = new ThreadContext(settings);

0 commit comments

Comments
 (0)