Skip to content

Commit

Permalink
BIGTOP-4344: Add ut cases for tools classes in server module (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
xianrenzw authored Feb 4, 2025
1 parent f12ae98 commit e10701d
Show file tree
Hide file tree
Showing 4 changed files with 570 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.tools.functions;

import org.apache.bigtop.manager.server.model.vo.ClusterVO;
import org.apache.bigtop.manager.server.service.ClusterService;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.tool.ToolExecutor;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class ClusterFunctionsTest {

@Mock
private ClusterService clusterService;

@InjectMocks
private ClusterFunctions clusterFunctions;

private ClusterVO testCluster;

@BeforeEach
void setUp() {
testCluster = new ClusterVO();
testCluster.setId(1L);
testCluster.setName("test-cluster");
}

@Test
void testListCluster() {
// Mock clusterService response
when(clusterService.list()).thenReturn(Collections.singletonList(testCluster));

// Get the tool specification and executor
Map<ToolSpecification, ToolExecutor> tools = clusterFunctions.listCluster();
assertEquals(1, tools.size());

ToolSpecification spec = tools.keySet().iterator().next();
ToolExecutor executor = tools.get(spec);

// Execute the tool
String result = executor.execute(ToolExecutionRequest.builder().build(), "memoryId");

// Verify results
assertTrue(result.contains("test-cluster"));
verify(clusterService, times(1)).list();
}

@Test
void testGetClusterById() {
// Mock clusterService response
when(clusterService.get(1L)).thenReturn(testCluster);

// Get the tool specification and executor
Map<ToolSpecification, ToolExecutor> tools = clusterFunctions.getClusterById();
assertEquals(1, tools.size());

ToolSpecification spec = tools.keySet().iterator().next();
ToolExecutor executor = tools.get(spec);

// Build request with arguments
String arguments = "{\"clusterId\": 1}";
ToolExecutionRequest request =
ToolExecutionRequest.builder().arguments(arguments).build();

// Execute the tool
String result = executor.execute(request, "memoryId");

// Verify results
assertTrue(result.contains("test-cluster"));
verify(clusterService, times(1)).get(1L);
}

@Test
void testGetClusterByIdWhenNotExists() {
// Mock clusterService response
when(clusterService.get(999L)).thenReturn(null);

// Get the tool specification and executor
Map<ToolSpecification, ToolExecutor> tools = clusterFunctions.getClusterById();
ToolExecutor executor = tools.values().iterator().next();

// Build request with arguments
String arguments = "{\"clusterId\": 999}";
ToolExecutionRequest request =
ToolExecutionRequest.builder().arguments(arguments).build();

// Execute the tool
String result = executor.execute(request, "memoryId");

// Verify results
assertEquals("Cluster not found", result);
}

@Test
void testGetClusterByName() {
// Mock clusterService response
when(clusterService.list()).thenReturn(Collections.singletonList(testCluster));

// Get the tool specification and executor
Map<ToolSpecification, ToolExecutor> tools = clusterFunctions.getClusterByName();
ToolExecutor executor = tools.values().iterator().next();

// Build request with arguments
String arguments = "{\"clusterName\": \"test-cluster\"}";
ToolExecutionRequest request =
ToolExecutionRequest.builder().arguments(arguments).build();

// Execute the tool
String result = executor.execute(request, "memoryId");

// Verify results
assertTrue(result.contains("test-cluster"));
verify(clusterService, times(1)).list();
}

@Test
void testGetClusterByNameWhenNotExists() {
// Mock clusterService response
when(clusterService.list()).thenReturn(Collections.singletonList(testCluster));

// Get the tool specification and executor
Map<ToolSpecification, ToolExecutor> tools = clusterFunctions.getClusterByName();
ToolExecutor executor = tools.values().iterator().next();

// Build request with arguments
String arguments = "{\"clusterName\": \"non-existent\"}";
ToolExecutionRequest request =
ToolExecutionRequest.builder().arguments(arguments).build();

// Execute the tool
String result = executor.execute(request, "memoryId");

// Verify results
assertEquals("Cluster not found", result);
}

@Test
void testGetAllFunctions() {
Map<ToolSpecification, ToolExecutor> functions = clusterFunctions.getAllFunctions();
assertEquals(3, functions.size());

List<String> expectedToolNames = List.of("listCluster", "getClusterById", "getClusterByName");
assertTrue(functions.keySet().stream().map(ToolSpecification::name).allMatch(expectedToolNames::contains));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.tools.functions;

import org.apache.bigtop.manager.dao.query.HostQuery;
import org.apache.bigtop.manager.server.model.vo.HostVO;
import org.apache.bigtop.manager.server.model.vo.PageVO;
import org.apache.bigtop.manager.server.service.HostService;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.tool.ToolExecutor;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class HostFunctionsTest {

@Mock
private HostService hostService;

@InjectMocks
private HostFunctions hostFunctions;

private HostVO testHost;
private PageVO<HostVO> testPage;

@BeforeEach
void setUp() {
testHost = new HostVO();
testHost.setId(1L);
testHost.setHostname("test-host");

testPage = new PageVO<>();
testPage.setContent(List.of(testHost));
testPage.setTotal(1L);
}

@Test
void testGetHostByIdToolSpecification() {
Map<ToolSpecification, ToolExecutor> tools = hostFunctions.getHostById();
assertEquals(1, tools.size());

ToolSpecification spec = tools.keySet().iterator().next();
Map<String, Map<String, Object>> params = spec.parameters().properties();

assertEquals(1, params.size());
assertTrue(params.containsKey("hostId"));
assertEquals("number", params.get("hostId").get("type"));
assertEquals("host id", params.get("hostId").get("description"));
}

@Test
void testGetHostByIdExecutorFound() throws Exception {
when(hostService.get(1L)).thenReturn(testHost);

Map<ToolSpecification, ToolExecutor> tools = hostFunctions.getHostById();
ToolExecutor executor = tools.values().iterator().next();

String arguments = "{\"hostId\": 1}";
String result = executor.execute(
ToolExecutionRequest.builder().arguments(arguments).build(), null);

// Use system-independent newline character regex
String expectedPattern = ".*\"hostname\"\\s*:\\s*\"test-host\".*";
assertTrue(
result.replaceAll("\\R", System.lineSeparator()).matches("(?s)" + expectedPattern),
"Hostname should match with any line separators");
}

@Test
void testGetHostByIdExecutorNotFound() {
when(hostService.get(anyLong())).thenReturn(null);

Map<ToolSpecification, ToolExecutor> tools = hostFunctions.getHostById();
ToolExecutor executor = tools.values().iterator().next();

String arguments = "{\"hostId\": 999}";
String result = executor.execute(
ToolExecutionRequest.builder().arguments(arguments).build(), null);

assertEquals("Host not found", result);
}

@Test
void testGetHostByNameToolSpecification() {
Map<ToolSpecification, ToolExecutor> tools = hostFunctions.getHostByName();
assertEquals(1, tools.size());

ToolSpecification spec = tools.keySet().iterator().next();
assertEquals("getHostByName", spec.name());
assertEquals("Get host information based on cluster name", spec.description());
Map<String, Map<String, Object>> params = spec.parameters().properties();
assertEquals(1, params.size());
assertTrue(params.containsKey("hostName"));
assertEquals("string", params.get("hostName").get("type"));
}

@Test
void testGetHostByNameExecutor() {
HostQuery query = new HostQuery();
query.setHostname("test-host");
when(hostService.list(query)).thenReturn(testPage);

Map<ToolSpecification, ToolExecutor> tools = hostFunctions.getHostByName();
ToolExecutor executor = tools.values().iterator().next();

String arguments = "{\"hostName\":\"test-host\"}";
String result = executor.execute(
ToolExecutionRequest.builder().arguments(arguments).build(), null);

// System-independent matching pattern
String totalPattern = "(?s).*\"total\"\\s*:\\s*1.*";
String hostPattern = "(?s).*\"hostname\"\\s*:\\s*\"test-host\".*";
assertTrue(result.matches(totalPattern), "Should contain total=1");
assertTrue(result.matches(hostPattern), "Should contain hostname=test-host");
}

@Test
void testGetAllFunctions() {
Map<ToolSpecification, ToolExecutor> functions = hostFunctions.getAllFunctions();
assertEquals(2, functions.size());
assertTrue(functions.keySet().stream().anyMatch(s -> s.name().equals("getHostById")));
assertTrue(functions.keySet().stream().anyMatch(s -> s.name().equals("getHostByName")));
}
}
Loading

0 comments on commit e10701d

Please sign in to comment.