diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tool/MLToolExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tool/MLToolExecutorTest.java index 58bd2199c8..59b91f3f65 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tool/MLToolExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tool/MLToolExecutorTest.java @@ -202,4 +202,26 @@ public void test_ImmutableEmptyParametersMap() { Output output = outputCaptor.getValue(); Assert.assertTrue(output instanceof ModelTensorOutput); } + + @Test + public void test_ToolExecutionFailsWithoutProperPermission() { + when(toolMLInput.getToolName()).thenReturn("TestTool"); + when(toolMLInput.getInputDataset()).thenReturn(inputDataSet); + when(inputDataSet.getParameters()).thenReturn(parameters); + when(toolFactory.create(any())).thenReturn(tool); + when(tool.validate(parameters)).thenReturn(true); + + Mockito.doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new SecurityException("no permissions for [indices:data/read/search] and User [name=test_user]")); + return null; + }).when(tool).run(Mockito.eq(parameters), any()); + + mlToolExecutor.execute(toolMLInput, actionListener); + + Mockito.verify(actionListener).onFailure(exceptionCaptor.capture()); + Exception exception = exceptionCaptor.getValue(); + Assert.assertTrue(exception instanceof SecurityException); + Assert.assertTrue(exception.getMessage().contains("no permissions")); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ReadFromScratchPadToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ReadFromScratchPadToolTests.java index b94960e107..e5f71e4663 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ReadFromScratchPadToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ReadFromScratchPadToolTests.java @@ -217,6 +217,21 @@ public void testRun_StringConversion_AddPersistentNote() { assertEquals("[\"existing\",\"new note\"]", parameters.get(ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY)); } + @Test + public void testRun_SecurityException() { + Map parameters = new HashMap<>(); + parameters.put(ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY, "[\"confidential data\"]"); + + SecurityException securityException = new SecurityException("no permissions for [indices:data/read/get] and User [name=test_user]"); + listener.onFailure(securityException); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + Exception exception = exceptionCaptor.getValue(); + assertTrue(exception instanceof SecurityException); + assertTrue(exception.getMessage().contains("no permissions")); + } + @Test public void testFactory() { ReadFromScratchPadTool.Factory factory = ReadFromScratchPadTool.Factory.getInstance(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/WriteToScratchPadToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/WriteToScratchPadToolTests.java index 35419d1872..081add09fc 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/WriteToScratchPadToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/WriteToScratchPadToolTests.java @@ -222,7 +222,23 @@ public void testRun_StringConversion_WithJsonArray() { ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); verify(listener).onResponse(captor.capture()); assertEquals("Wrote to scratchpad: new note", captor.getValue()); - assertEquals("[\"existing note\",\"new note\"]", parameters.get(WriteToScratchPadTool.SCRATCHPAD_NOTES_KEY)); + } + + @Test + public void testRun_SecurityException() { + Map parameters = new HashMap<>(); + parameters.put(WriteToScratchPadTool.NOTES_KEY, "confidential test data"); + + SecurityException securityException = new SecurityException( + "no permissions for [indices:data/write/index] and User [name=test_user]" + ); + listener.onFailure(securityException); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + Exception exception = exceptionCaptor.getValue(); + assertTrue(exception instanceof SecurityException); + assertTrue(exception.getMessage().contains("no permissions")); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index 26c41d5e49..1fee6e4c9c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -410,7 +410,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" + " \"timeout\": %d,\n" - + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}}] }]\n" + + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"type\": \"image_url\", \"image_url\": {\"%s\": \"%s\"}}] }]\n" + " }\n" + " }\n" + "}"; @@ -446,7 +446,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" + " \"timeout\": %d,\n" - + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}} , {\"document\": {\"format\": \"%s\", \"name\": \"%s\", \"data\": \"%s\"}}] }]\n" + + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"type\": \"image_url\", \"image_url\": {\"%s\": \"%s\"}} , {\"document\": {\"format\": \"%s\", \"name\": \"%s\", \"data\": \"%s\"}}] }]\n" + " }\n" + " }\n" + "}"; @@ -505,7 +505,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"context_size\": %d,\n" + " \"message_size\": %d,\n" + " \"timeout\": %d,\n" - + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"image\": {\"format\": \"%s\", \"%s\": \"%s\"}}] }]\n" + + " \"llm_messages\": [{ \"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"%s\"}, {\"type\": \"image_url\", \"image_url\": {\"%s\": \"%s\"}}] }]\n" + " }\n" + " }\n" + "}"; diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ListIndexToolIT.java b/plugin/src/test/java/org/opensearch/ml/tools/ListIndexToolIT.java index a38c6781c6..1dee2fba7e 100644 --- a/plugin/src/test/java/org/opensearch/ml/tools/ListIndexToolIT.java +++ b/plugin/src/test/java/org/opensearch/ml/tools/ListIndexToolIT.java @@ -14,9 +14,13 @@ import java.util.Objects; import org.apache.commons.lang3.StringUtils; +import org.apache.hc.core5.http.HttpHost; import org.junit.Before; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestClient; import org.opensearch.common.settings.Settings; +import org.opensearch.commons.rest.SecureRestClientBuilder; import org.opensearch.ml.engine.tools.ListIndexTool; import org.opensearch.ml.rest.RestBaseAgentToolsIT; import org.opensearch.ml.utils.TestHelper; @@ -37,6 +41,44 @@ public void setUpCluster() throws Exception { registerListIndexFlowAgent(); } + public void testListIndexWithNoPermissions() throws Exception { + if (!isHttps()) { + log.info("Skipping permission test as security is not enabled"); + return; + } + + String noPermissionUser = "no_permission_user"; + String password = "TestPassword123!"; + + try { + createUser(noPermissionUser, password, new ArrayList<>()); + + final RestClient noPermissionClient = new SecureRestClientBuilder( + getClusterHosts().toArray(new HttpHost[0]), + isHttps(), + noPermissionUser, + password + ).setSocketTimeout(60000).build(); + + try { + ResponseException exception = expectThrows(ResponseException.class, () -> { + TestHelper + .makeRequest(noPermissionClient, "POST", "/_plugins/_ml/agents/" + agentId + "/_execute", null, question, null); + }); + + String errorMessage = exception.getMessage().toLowerCase(); + assertTrue( + "Expected permission error, got: " + errorMessage, + errorMessage.contains("no permissions") || errorMessage.contains("forbidden") || errorMessage.contains("unauthorized") + ); + } finally { + noPermissionClient.close(); + } + } finally { + deleteUser(noPermissionUser); + } + } + private List createIndices(int count) throws IOException { List indices = new ArrayList<>(); for (int i = 0; i < count; i++) { diff --git a/plugin/src/test/java/org/opensearch/ml/tools/ScratchPadToolIT.java b/plugin/src/test/java/org/opensearch/ml/tools/ScratchPadToolIT.java new file mode 100644 index 0000000000..61aad92b4a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/tools/ScratchPadToolIT.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.tools; + +import static org.opensearch.ml.utils.TestHelper.makeRequest; + +import org.junit.Before; +import org.opensearch.ml.rest.MLCommonsRestTestCase; +import org.opensearch.test.OpenSearchIntegTestCase; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 3) +public class ScratchPadToolIT extends MLCommonsRestTestCase { + + @Before + public void setUp() throws Exception { + super.setUp(); + } + + public void testScratchpadSizeLimit() throws Exception { + String largeContent = "A".repeat(100 * 1024 * 1024); + String requestBody = String.format("{\"parameters\":{\"notes\":\"%s\"}}", largeContent); + + Exception exception = expectThrows(Exception.class, () -> { + makeRequest(client(), "POST", "/_plugins/_ml/tools/_execute/WriteToScratchPadTool", null, requestBody, null); + }); + + String errorMessage = exception.getMessage().toLowerCase(); + assertTrue( + "Expected HTTP content length error, got: " + errorMessage, + errorMessage.contains("content length") + || errorMessage.contains("too large") + || errorMessage.contains("entity too large") + || errorMessage.contains("413") + ); + } +}