Skip to content

Commit f934f16

Browse files
[FEATURE] support tools enhanced by AOP (#80)
Issues link: langchain4j/langchain4j#2113
1 parent 10c6cdc commit f934f16

File tree

7 files changed

+188
-2
lines changed

7 files changed

+188
-2
lines changed

Diff for: langchain4j-spring-boot-starter/pom.xml

+7
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@
5353
<scope>test</scope>
5454
</dependency>
5555

56+
<dependency>
57+
<groupId>org.springframework.boot</groupId>
58+
<artifactId>spring-boot-starter-aop</artifactId>
59+
<version>${spring.boot.version}</version>
60+
<scope>test</scope>
61+
</dependency>
62+
5663
<dependency>
5764
<groupId>dev.langchain4j</groupId>
5865
<artifactId>langchain4j-core</artifactId>

Diff for: langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServiceFactory.java

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package dev.langchain4j.service.spring;
22

3+
import dev.langchain4j.agent.tool.Tool;
4+
import dev.langchain4j.agent.tool.ToolSpecification;
35
import dev.langchain4j.memory.ChatMemory;
46
import dev.langchain4j.memory.chat.ChatMemoryProvider;
57
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -8,11 +10,20 @@
810
import dev.langchain4j.rag.RetrievalAugmentor;
911
import dev.langchain4j.rag.content.retriever.ContentRetriever;
1012
import dev.langchain4j.service.AiServices;
13+
import dev.langchain4j.service.tool.DefaultToolExecutor;
14+
import dev.langchain4j.service.tool.ToolExecutor;
1115
import org.springframework.beans.factory.FactoryBean;
1216

17+
import java.lang.reflect.Method;
18+
import java.util.Arrays;
19+
import java.util.HashMap;
1320
import java.util.List;
21+
import java.util.Map;
1422

23+
import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
1524
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
25+
import static org.springframework.aop.framework.AopProxyUtils.ultimateTargetClass;
26+
import static org.springframework.aop.support.AopUtils.isAopProxy;
1627

1728
class AiServiceFactory implements FactoryBean<Object> {
1829

@@ -94,7 +105,13 @@ public Object getObject() {
94105
}
95106

96107
if (!isNullOrEmpty(tools)) {
97-
builder = builder.tools(tools);
108+
for (Object tool : tools) {
109+
if (isAopProxy(tool)) {
110+
builder = builder.tools(aopEnhancedTools(tool));
111+
} else {
112+
builder = builder.tools(tool);
113+
}
114+
}
98115
}
99116

100117
return builder.build();
@@ -120,4 +137,21 @@ public boolean isSingleton() {
120137
* (such as java.io.Closeable.close()) will not be called automatically.
121138
* Instead, a FactoryBean should implement DisposableBean and delegate any such close call to the underlying object.
122139
*/
140+
141+
private Map<ToolSpecification, ToolExecutor> aopEnhancedTools(Object enhancedTool) {
142+
Map<ToolSpecification, ToolExecutor> toolExecutors = new HashMap<>();
143+
Class<?> originalToolClass = ultimateTargetClass(enhancedTool);
144+
for (Method originalToolMethod : originalToolClass.getDeclaredMethods()) {
145+
if (originalToolMethod.isAnnotationPresent(Tool.class)) {
146+
Arrays.stream(enhancedTool.getClass().getDeclaredMethods())
147+
.filter(m -> m.getName().equals(originalToolMethod.getName()))
148+
.findFirst()
149+
.ifPresent(enhancedMethod -> {
150+
ToolSpecification toolSpecification = toolSpecificationFrom(originalToolMethod);
151+
toolExecutors.put(toolSpecification, new DefaultToolExecutor(enhancedTool, enhancedMethod));
152+
});
153+
}
154+
}
155+
return toolExecutors;
156+
}
123157
}

Diff for: langchain4j-spring-boot-starter/src/main/java/dev/langchain4j/service/spring/AiServicesAutoConfig.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
4949
Set<String> tools = new HashSet<>();
5050
for (String beanName : beanFactory.getBeanDefinitionNames()) {
5151
try {
52-
Class<?> beanClass = Class.forName(beanFactory.getBeanDefinition(beanName).getBeanClassName());
52+
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
53+
if (beanClassName == null) {
54+
continue;
55+
}
56+
Class<?> beanClass = Class.forName(beanClassName);
5357
for (Method beanMethod : beanClass.getDeclaredMethods()) {
5458
if (beanMethod.isAnnotationPresent(Tool.class)) {
5559
tools.add(beanName);

Diff for: langchain4j-spring-boot-starter/src/test/java/dev/langchain4j/service/spring/mode/automatic/withTools/AiServicesAutoConfigIT.java

+48
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
package dev.langchain4j.service.spring.mode.automatic.withTools;
22

33
import dev.langchain4j.service.spring.AiServicesAutoConfig;
4+
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect;
45
import org.junit.jupiter.api.Test;
56
import org.springframework.boot.autoconfigure.AutoConfigurations;
67
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
78

89
import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
10+
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY;
11+
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION;
12+
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME;
13+
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION;
914
import static dev.langchain4j.service.spring.mode.automatic.withTools.PackagePrivateTools.CURRENT_TIME;
1015
import static dev.langchain4j.service.spring.mode.automatic.withTools.PublicTools.CURRENT_DATE;
1116
import static org.assertj.core.api.Assertions.assertThat;
17+
import static org.junit.jupiter.api.Assertions.assertEquals;
18+
import static org.junit.jupiter.api.Assertions.assertFalse;
19+
import static org.junit.jupiter.api.Assertions.assertTrue;
1220

1321
class AiServicesAutoConfigIT {
1422

@@ -61,6 +69,46 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag
6169
});
6270
}
6371

72+
@Test
73+
void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() {
74+
contextRunner
75+
.withPropertyValues(
76+
"langchain4j.open-ai.chat-model.api-key=" + OPENAI_API_KEY,
77+
"langchain4j.open-ai.chat-model.temperature=0.0",
78+
"langchain4j.open-ai.chat-model.log-requests=true",
79+
"langchain4j.open-ai.chat-model.log-responses=true"
80+
)
81+
.withUserConfiguration(AiServiceWithToolsApplication.class)
82+
.run(context -> {
83+
84+
// given
85+
AiServiceWithTools aiService = context.getBean(AiServiceWithTools.class);
86+
87+
// when
88+
String answer = aiService.chat("Which package is the @ToolObserver annotation located in? " +
89+
"And what is the key of the @ToolObserver annotation?" +
90+
"And What is the current time?");
91+
92+
System.out.println("Answer: " + answer);
93+
94+
// then should use AopEnhancedTools.getAspectPackage()
95+
// & AopEnhancedTools.getToolObserverKey()
96+
// & PackagePrivateTools.getCurrentTime()
97+
assertThat(answer).contains(TOOL_OBSERVER_PACKAGE_NAME);
98+
assertThat(answer).contains(TOOL_OBSERVER_KEY);
99+
assertThat(answer).contains(String.valueOf(CURRENT_TIME.getMinute()));
100+
101+
// and AOP aspect should be called
102+
// & only for getToolObserverKey() which is annotated with @ToolObserver
103+
ToolObserverAspect aspect = context.getBean(ToolObserverAspect.class);
104+
assertTrue(aspect.aspectHasBeenCalled());
105+
106+
assertEquals(1, aspect.getObservedTools().size());
107+
assertTrue(aspect.getObservedTools().contains(TOOL_OBSERVER_KEY_NAME_DESCRIPTION));
108+
assertFalse(aspect.getObservedTools().contains(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION));
109+
});
110+
}
111+
64112
// TODO tools which are not @Beans?
65113
// TODO negative cases
66114
// TODO no @AiServices in app, just models
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package dev.langchain4j.service.spring.mode.automatic.withTools;
2+
3+
import dev.langchain4j.agent.tool.Tool;
4+
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserver;
5+
import org.springframework.stereotype.Component;
6+
7+
@Component
8+
public class AopEnhancedTools {
9+
10+
public static final String TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION =
11+
"Find the package directory where @ToolObserver is located.";
12+
public static final String TOOL_OBSERVER_PACKAGE_NAME = ToolObserver.class.getPackageName();
13+
14+
public static final String TOOL_OBSERVER_KEY_NAME_DESCRIPTION =
15+
"Find the key name of @ToolObserver";
16+
public static final String TOOL_OBSERVER_KEY = "AOP_ENHANCED_TOOLS_SUPPORT_@_1122";
17+
18+
@Tool(TOOL_OBSERVER_PACKAGE_NAME_DESCRIPTION)
19+
public String getToolObserverPackageName() {
20+
return TOOL_OBSERVER_PACKAGE_NAME;
21+
}
22+
23+
@ToolObserver(key = TOOL_OBSERVER_KEY)
24+
@Tool(TOOL_OBSERVER_KEY_NAME_DESCRIPTION)
25+
public String getToolObserverKey() {
26+
return TOOL_OBSERVER_KEY;
27+
}
28+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;
2+
3+
import java.lang.annotation.ElementType;
4+
import java.lang.annotation.Retention;
5+
import java.lang.annotation.RetentionPolicy;
6+
import java.lang.annotation.Target;
7+
8+
@Target({ElementType.METHOD})
9+
@Retention(RetentionPolicy.RUNTIME)
10+
public @interface ToolObserver {
11+
12+
/**
13+
* key just for example
14+
*
15+
* @return the key
16+
*/
17+
String key();
18+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package dev.langchain4j.service.spring.mode.automatic.withTools.aop;
2+
3+
import dev.langchain4j.agent.tool.Tool;
4+
import org.aspectj.lang.ProceedingJoinPoint;
5+
import org.aspectj.lang.annotation.Around;
6+
import org.aspectj.lang.annotation.Aspect;
7+
import org.aspectj.lang.reflect.MethodSignature;
8+
import org.springframework.stereotype.Component;
9+
10+
import java.util.ArrayList;
11+
import java.util.Arrays;
12+
import java.util.List;
13+
14+
@Aspect
15+
@Component
16+
public class ToolObserverAspect {
17+
18+
private final List<String> observedTools = new ArrayList<>();
19+
20+
@Around("@annotation(toolObserver)")
21+
public Object around(ProceedingJoinPoint joinPoint, ToolObserver toolObserver) throws Throwable {
22+
var signature = (MethodSignature) joinPoint.getSignature();
23+
var method = signature.getMethod();
24+
String methodName = method.getName();
25+
if (method.isAnnotationPresent(Tool.class)) {
26+
Tool toolAnnotation = method.getAnnotation(Tool.class);
27+
observedTools.addAll(Arrays.asList(toolAnnotation.value()));
28+
System.out.printf("Found @Tool %s for method: %s%n%n", Arrays.toString(toolAnnotation.value()), methodName);
29+
}
30+
Object result = joinPoint.proceed();
31+
System.out.printf(" | key: %s%n | Method name: %s%n | Method arguments: %s%n | Return type: %s%n | Method return value: %s%n%n",
32+
toolObserver.key(),
33+
methodName,
34+
Arrays.toString(joinPoint.getArgs()),
35+
method.getReturnType().getName(),
36+
result);
37+
return result;
38+
}
39+
40+
public boolean aspectHasBeenCalled() {
41+
return !observedTools.isEmpty();
42+
}
43+
44+
public List<String> getObservedTools() {
45+
return observedTools;
46+
}
47+
}

0 commit comments

Comments
 (0)