Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions assistant-agent-start/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
<artifactId>python-community</artifactId>
<type>pom</type>
</dependency>

<!-- Test Dependencies -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,21 @@
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand All @@ -47,6 +54,8 @@ public class MockBaiduWebSearchProvider implements SearchProvider {
private static final Logger logger = LoggerFactory.getLogger(MockBaiduWebSearchProvider.class);

private static final String BAIDU_SEARCH_URL = "https://www.baidu.com/s";
private static final String QIANFAN_API_URL = "https://qianfan.baidubce.com/v2/ai_search/web_search";
private static final ObjectMapper objectMapper = new ObjectMapper();
private static final int DEFAULT_TIMEOUT = 5000;

private final String apiKey;
Expand Down Expand Up @@ -154,31 +163,48 @@ private List<SearchResultItem> searchSimpleMode(String query, int topK) {
}

/**
* API模式搜索:使用百度搜索API
* TODO: 需要根据实际的百度搜索API文档实现
* API模式搜索:使用百度千帆智能搜索API
*/
private List<SearchResultItem> searchApiMode(String query, int topK) {
logger.info("MockBaiduWebSearchProvider#searchApiMode - reason=执行API模式搜索, query={}, topK={}", query, topK);

// TODO: 实现百度搜索API调用
// 这里暂时返回模拟数据
List<SearchResultItem> results = new ArrayList<>();

for (int i = 0; i < Math.min(topK, 3); i++) {
SearchResultItem item = new SearchResultItem();
item.setId(UUID.randomUUID().toString());
item.setSourceType(SearchSourceType.WEB);
item.setTitle("百度搜索结果 " + (i + 1) + ": " + query);
item.setSnippet("这是针对查询'" + query + "'的百度搜索结果摘要信息");
item.setContent("详细内容:这是从百度搜索获取的完整内容,包含了与查询相关的详细信息。");
item.setUri("https://www.example.com/result" + i);
item.setScore(0.95 - i * 0.1);
item.getMetadata().setSourceName(getName());
item.getMetadata().setLanguage("zh");
results.add(item);
}
try {
// 构建并发送HTTP请求
URL url = new URL(QIANFAN_API_URL);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
connection.setConnectTimeout(DEFAULT_TIMEOUT);
connection.setReadTimeout(DEFAULT_TIMEOUT);
connection.setRequestProperty("Authorization", "Bearer " + apiKey);
connection.setRequestProperty("Content-Type", "application/json");
connection.setDoOutput(true);

// 构建请求体(使用Jackson确保正确转义)
String requestBody = buildApiRequestBody(query, topK);
try (OutputStream os = connection.getOutputStream()) {
os.write(requestBody.getBytes(StandardCharsets.UTF_8));
}

logger.info("MockBaiduWebSearchProvider#searchApiMode - reason=API模式搜索完成, resultCount={}", results.size());
int responseCode = connection.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
String response = readHttpResponse(connection);
results = parseApiResponse(response);
logger.info("MockBaiduWebSearchProvider#searchApiMode - reason=API请求成功, resultCount={}", results.size());
} else {
String errorResponse = readErrorResponse(connection);
logger.error("MockBaiduWebSearchProvider#searchApiMode - reason=API请求失败, responseCode={}, error={}",
responseCode, errorResponse);
// 降级:返回兜底示例数据
results = createFallbackResults(query, topK);
}
connection.disconnect();
} catch (Exception e) {
logger.error("MockBaiduWebSearchProvider#searchApiMode - reason=API模式搜索异常, error={}", e.getMessage(), e);
// 降级:返回兜底示例数据
results = createFallbackResults(query, topK);
}

return results;
}
Expand Down Expand Up @@ -298,6 +324,130 @@ private String cleanHtmlText(String html) {
return text.trim();
}

/**
* 创建兜底示例数据(API调用失败时使用)
*/
private List<SearchResultItem> createFallbackResults(String query, int topK) {
logger.info("MockBaiduWebSearchProvider#createFallbackResults - reason=返回兜底示例数据, query={}", query);

List<SearchResultItem> results = new ArrayList<>();
int count = Math.min(topK, 3);

for (int i = 0; i < count; i++) {
SearchResultItem item = new SearchResultItem();
item.setId(UUID.randomUUID().toString());
item.setSourceType(SearchSourceType.WEB);
item.setTitle("[示例] 百度搜索结果 " + (i + 1) + ": " + query);
item.setSnippet("这是针对查询'" + query + "'的示例搜索结果摘要信息(API调用失败,返回兜底数据)");
item.setContent("示例内容:这是兜底数据,实际API调用失败。请检查API Key配置或网络连接。");
item.setUri("https://www.baidu.com/s?wd=" + query);
item.setScore(0.95 - i * 0.1);
item.getMetadata().setSourceName(getName());
item.getMetadata().setLanguage("zh");
results.add(item);
}

return results;
}

/**
* 构建API请求体
*/
private String buildApiRequestBody(String query, int topK) throws Exception {
// 限制 topK 最大值为50
int limitedTopK = Math.min(topK, 10);

Map<String, Object> request = new HashMap<>();
request.put("messages", List.of(Map.of("role", "user", "content", query)));
request.put("stream", false);
request.put("resource_type_filter", List.of(Map.of("type", "web", "top_k", limitedTopK)));

return objectMapper.writeValueAsString(request);
}

/**
* 读取HTTP响应
*/
private String readHttpResponse(HttpURLConnection connection) throws IOException {
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
StringBuilder response = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
response.append(line);
}
return response.toString();
}
}

/**
* 读取错误响应
*/
private String readErrorResponse(HttpURLConnection connection) {
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(connection.getErrorStream(), StandardCharsets.UTF_8))) {
StringBuilder response = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
response.append(line);
}
return response.toString();
} catch (Exception e) {
return "无法读取错误信息";
}
}

/**
* 解析API响应
*/
private List<SearchResultItem> parseApiResponse(String json) {
List<SearchResultItem> results = new ArrayList<>();

try {
JsonNode root = objectMapper.readTree(json);
JsonNode references = root.get("references");

if (references != null && references.isArray()) {
for (JsonNode ref : references) {
// 只处理 web 类型的结果
String type = getTextValue(ref, "type");
if (!"web".equals(type)) {
continue;
}

SearchResultItem item = new SearchResultItem();
item.setId(String.valueOf(ref.get("id").asInt()));
item.setSourceType(SearchSourceType.WEB);
item.setTitle(getTextValue(ref, "title"));
item.setSnippet(getTextValue(ref, "snippet"));
item.setContent(getTextValue(ref, "content"));
item.setUri(getTextValue(ref, "url"));

// 使用 rerank_score 作为相关度评分
item.setScore(ref.has("rerank_score") ? ref.get("rerank_score").asDouble() : 0.0);

// 设置基本元数据
item.getMetadata().setSourceName(getName());
item.getMetadata().setLanguage("zh");

results.add(item);
}
}
} catch (Exception e) {
logger.error("MockBaiduWebSearchProvider#parseApiResponse - reason=解析响应失败, error={}", e.getMessage(), e);
}

return results;
}

/**
* 安全获取文本值
*/
private String getTextValue(JsonNode node, String field) {
JsonNode fieldNode = node.get(field);
return (fieldNode != null && !fieldNode.isNull()) ? fieldNode.asText() : "";
}

@Override
public String getName() {
return "MockBaiduWebSearchProvider";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.assistant.agent.start.search;

import com.alibaba.assistant.agent.extension.search.model.SearchRequest;
import com.alibaba.assistant.agent.extension.search.model.SearchResultItem;
import com.alibaba.assistant.agent.extension.search.model.SearchSourceType;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
* MockBaiduWebSearchProvider 单元测试
*
* @author Assistant Agent Team
*/
class MockBaiduWebSearchProviderTest {

@Test
@DisplayName("测试简单模式构造 - null API Key应使用简单模式")
void testConstructor_withNullApiKey_shouldUseSimpleMode() {
MockBaiduWebSearchProvider p = new MockBaiduWebSearchProvider(null);
SearchRequest request = new SearchRequest("Spring AI Alibaba");
request.setTopK(3);

List<SearchResultItem> results = p.search(request);

assertNotNull(results);
assertFalse(results.isEmpty());
}

@Test
@DisplayName("集成测试 - 使用真实API Key")
@EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+")
void testRealApiIntegration() {
String realApiKey = System.getenv("QIANFAN_API_KEY");
MockBaiduWebSearchProvider realProvider = new MockBaiduWebSearchProvider(realApiKey);

SearchRequest request = new SearchRequest("Spring AI Alibaba");
request.setTopK(5);

List<SearchResultItem> results = realProvider.search(request);

assertNotNull(results);
assertFalse(results.isEmpty(), "真实API应返回搜索结果");

System.out.println("=== 真实API返回结果 ===");
System.out.println("结果数量: " + results.size());

for (SearchResultItem item : results) {
System.out.println("---");
System.out.println("标题: " + item.getTitle());
System.out.println("URL: " + item.getUri());
System.out.println("评分: " + item.getScore());
String snippet = item.getSnippet();
if (snippet != null && snippet.length() > 100) {
snippet = snippet.substring(0, 100) + "...";
}
System.out.println("摘要: " + snippet);
}

// 验证第一个结果的完整性
SearchResultItem first = results.get(0);
assertNotNull(first.getId());
assertEquals(SearchSourceType.WEB, first.getSourceType());
assertNotNull(first.getTitle());
assertFalse(first.getTitle().isEmpty());
}
}