From 698a586dba9a92153e272cf2bc1e86556726e29c Mon Sep 17 00:00:00 2001 From: wangx Date: Sun, 18 Jan 2026 15:42:43 +0800 Subject: [PATCH] =?UTF-8?q?feat(search):=20=E4=BD=BF=E7=94=A8=E7=99=BE?= =?UTF-8?q?=E5=BA=A6=E5=8D=83=E5=B8=86=E6=99=BA=E8=83=BD=E6=90=9C=E7=B4=A2?= =?UTF-8?q?API=E6=9B=BF=E4=BB=A3=E6=A8=A1=E6=8B=9F=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增基于API调用的百度千帆智能搜索实现 - 实现POST请求构建与JSON请求体封装,保证正确传递参数 - 解析API响应JSON,提取web类型搜索结果并转换为SearchResultItem对象 - 添加失败时返回兜底示例数据的降级方案,保证接口稳定性 - 增加HTTP响应与错误响应读取工具方法,改善代码复用 - 添加单元测试验证无API Key时使用简单模式,且支持环境变量配置进行真实API集成测试 - 在pom.xml中加入JUnit Jupiter测试依赖支持单元测试运行 --- assistant-agent-start/pom.xml | 7 + .../search/MockBaiduWebSearchProvider.java | 186 ++++++++++++++++-- .../MockBaiduWebSearchProviderTest.java | 87 ++++++++ 3 files changed, 262 insertions(+), 18 deletions(-) create mode 100644 assistant-agent-start/src/test/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProviderTest.java diff --git a/assistant-agent-start/pom.xml b/assistant-agent-start/pom.xml index 0d3511b..eb86883 100644 --- a/assistant-agent-start/pom.xml +++ b/assistant-agent-start/pom.xml @@ -43,6 +43,13 @@ python-community pom + + + + org.junit.jupiter + junit-jupiter + test + diff --git a/assistant-agent-start/src/main/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProvider.java b/assistant-agent-start/src/main/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProvider.java index f37f1f4..5d56e5e 100644 --- a/assistant-agent-start/src/main/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProvider.java +++ b/assistant-agent-start/src/main/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProvider.java @@ -23,14 +23,21 @@ import org.slf4j.LoggerFactory; import org.springframework.stereotype.Component; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + import java.io.BufferedReader; +import java.io.IOException; import java.io.InputStreamReader; +import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.URL; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.UUID; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -47,6 +54,8 @@ public class MockBaiduWebSearchProvider implements SearchProvider { private static final Logger logger = LoggerFactory.getLogger(MockBaiduWebSearchProvider.class); private static final String BAIDU_SEARCH_URL = "https://www.baidu.com/s"; + private static final String QIANFAN_API_URL = "https://qianfan.baidubce.com/v2/ai_search/web_search"; + private static final ObjectMapper objectMapper = new ObjectMapper(); private static final int DEFAULT_TIMEOUT = 5000; private final String apiKey; @@ -154,31 +163,48 @@ private List searchSimpleMode(String query, int topK) { } /** - * API模式搜索:使用百度搜索API - * TODO: 需要根据实际的百度搜索API文档实现 + * API模式搜索:使用百度千帆智能搜索API */ private List searchApiMode(String query, int topK) { logger.info("MockBaiduWebSearchProvider#searchApiMode - reason=执行API模式搜索, query={}, topK={}", query, topK); - // TODO: 实现百度搜索API调用 - // 这里暂时返回模拟数据 List results = new ArrayList<>(); - for (int i = 0; i < Math.min(topK, 3); i++) { - SearchResultItem item = new SearchResultItem(); - item.setId(UUID.randomUUID().toString()); - item.setSourceType(SearchSourceType.WEB); - item.setTitle("百度搜索结果 " + (i + 1) + ": " + query); - item.setSnippet("这是针对查询'" + query + "'的百度搜索结果摘要信息"); - item.setContent("详细内容:这是从百度搜索获取的完整内容,包含了与查询相关的详细信息。"); - item.setUri("https://www.example.com/result" + i); - item.setScore(0.95 - i * 0.1); - item.getMetadata().setSourceName(getName()); - item.getMetadata().setLanguage("zh"); - results.add(item); - } + try { + // 构建并发送HTTP请求 + URL url = new URL(QIANFAN_API_URL); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("POST"); + connection.setConnectTimeout(DEFAULT_TIMEOUT); + connection.setReadTimeout(DEFAULT_TIMEOUT); + connection.setRequestProperty("Authorization", "Bearer " + apiKey); + connection.setRequestProperty("Content-Type", "application/json"); + connection.setDoOutput(true); + + // 构建请求体(使用Jackson确保正确转义) + String requestBody = buildApiRequestBody(query, topK); + try (OutputStream os = connection.getOutputStream()) { + os.write(requestBody.getBytes(StandardCharsets.UTF_8)); + } - logger.info("MockBaiduWebSearchProvider#searchApiMode - reason=API模式搜索完成, resultCount={}", results.size()); + int responseCode = connection.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_OK) { + String response = readHttpResponse(connection); + results = parseApiResponse(response); + logger.info("MockBaiduWebSearchProvider#searchApiMode - reason=API请求成功, resultCount={}", results.size()); + } else { + String errorResponse = readErrorResponse(connection); + logger.error("MockBaiduWebSearchProvider#searchApiMode - reason=API请求失败, responseCode={}, error={}", + responseCode, errorResponse); + // 降级:返回兜底示例数据 + results = createFallbackResults(query, topK); + } + connection.disconnect(); + } catch (Exception e) { + logger.error("MockBaiduWebSearchProvider#searchApiMode - reason=API模式搜索异常, error={}", e.getMessage(), e); + // 降级:返回兜底示例数据 + results = createFallbackResults(query, topK); + } return results; } @@ -298,6 +324,130 @@ private String cleanHtmlText(String html) { return text.trim(); } + /** + * 创建兜底示例数据(API调用失败时使用) + */ + private List createFallbackResults(String query, int topK) { + logger.info("MockBaiduWebSearchProvider#createFallbackResults - reason=返回兜底示例数据, query={}", query); + + List results = new ArrayList<>(); + int count = Math.min(topK, 3); + + for (int i = 0; i < count; i++) { + SearchResultItem item = new SearchResultItem(); + item.setId(UUID.randomUUID().toString()); + item.setSourceType(SearchSourceType.WEB); + item.setTitle("[示例] 百度搜索结果 " + (i + 1) + ": " + query); + item.setSnippet("这是针对查询'" + query + "'的示例搜索结果摘要信息(API调用失败,返回兜底数据)"); + item.setContent("示例内容:这是兜底数据,实际API调用失败。请检查API Key配置或网络连接。"); + item.setUri("https://www.baidu.com/s?wd=" + query); + item.setScore(0.95 - i * 0.1); + item.getMetadata().setSourceName(getName()); + item.getMetadata().setLanguage("zh"); + results.add(item); + } + + return results; + } + + /** + * 构建API请求体 + */ + private String buildApiRequestBody(String query, int topK) throws Exception { + // 限制 topK 最大值为50 + int limitedTopK = Math.min(topK, 10); + + Map request = new HashMap<>(); + request.put("messages", List.of(Map.of("role", "user", "content", query))); + request.put("stream", false); + request.put("resource_type_filter", List.of(Map.of("type", "web", "top_k", limitedTopK))); + + return objectMapper.writeValueAsString(request); + } + + /** + * 读取HTTP响应 + */ + private String readHttpResponse(HttpURLConnection connection) throws IOException { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) { + StringBuilder response = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + response.append(line); + } + return response.toString(); + } + } + + /** + * 读取错误响应 + */ + private String readErrorResponse(HttpURLConnection connection) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(connection.getErrorStream(), StandardCharsets.UTF_8))) { + StringBuilder response = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + response.append(line); + } + return response.toString(); + } catch (Exception e) { + return "无法读取错误信息"; + } + } + + /** + * 解析API响应 + */ + private List parseApiResponse(String json) { + List results = new ArrayList<>(); + + try { + JsonNode root = objectMapper.readTree(json); + JsonNode references = root.get("references"); + + if (references != null && references.isArray()) { + for (JsonNode ref : references) { + // 只处理 web 类型的结果 + String type = getTextValue(ref, "type"); + if (!"web".equals(type)) { + continue; + } + + SearchResultItem item = new SearchResultItem(); + item.setId(String.valueOf(ref.get("id").asInt())); + item.setSourceType(SearchSourceType.WEB); + item.setTitle(getTextValue(ref, "title")); + item.setSnippet(getTextValue(ref, "snippet")); + item.setContent(getTextValue(ref, "content")); + item.setUri(getTextValue(ref, "url")); + + // 使用 rerank_score 作为相关度评分 + item.setScore(ref.has("rerank_score") ? ref.get("rerank_score").asDouble() : 0.0); + + // 设置基本元数据 + item.getMetadata().setSourceName(getName()); + item.getMetadata().setLanguage("zh"); + + results.add(item); + } + } + } catch (Exception e) { + logger.error("MockBaiduWebSearchProvider#parseApiResponse - reason=解析响应失败, error={}", e.getMessage(), e); + } + + return results; + } + + /** + * 安全获取文本值 + */ + private String getTextValue(JsonNode node, String field) { + JsonNode fieldNode = node.get(field); + return (fieldNode != null && !fieldNode.isNull()) ? fieldNode.asText() : ""; + } + @Override public String getName() { return "MockBaiduWebSearchProvider"; diff --git a/assistant-agent-start/src/test/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProviderTest.java b/assistant-agent-start/src/test/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProviderTest.java new file mode 100644 index 0000000..aa67e67 --- /dev/null +++ b/assistant-agent-start/src/test/java/com/alibaba/assistant/agent/start/search/MockBaiduWebSearchProviderTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.assistant.agent.start.search; + +import com.alibaba.assistant.agent.extension.search.model.SearchRequest; +import com.alibaba.assistant.agent.extension.search.model.SearchResultItem; +import com.alibaba.assistant.agent.extension.search.model.SearchSourceType; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * MockBaiduWebSearchProvider 单元测试 + * + * @author Assistant Agent Team + */ +class MockBaiduWebSearchProviderTest { + + @Test + @DisplayName("测试简单模式构造 - null API Key应使用简单模式") + void testConstructor_withNullApiKey_shouldUseSimpleMode() { + MockBaiduWebSearchProvider p = new MockBaiduWebSearchProvider(null); + SearchRequest request = new SearchRequest("Spring AI Alibaba"); + request.setTopK(3); + + List results = p.search(request); + + assertNotNull(results); + assertFalse(results.isEmpty()); + } + + @Test + @DisplayName("集成测试 - 使用真实API Key") + @EnabledIfEnvironmentVariable(named = "QIANFAN_API_KEY", matches = ".+") + void testRealApiIntegration() { + String realApiKey = System.getenv("QIANFAN_API_KEY"); + MockBaiduWebSearchProvider realProvider = new MockBaiduWebSearchProvider(realApiKey); + + SearchRequest request = new SearchRequest("Spring AI Alibaba"); + request.setTopK(5); + + List results = realProvider.search(request); + + assertNotNull(results); + assertFalse(results.isEmpty(), "真实API应返回搜索结果"); + + System.out.println("=== 真实API返回结果 ==="); + System.out.println("结果数量: " + results.size()); + + for (SearchResultItem item : results) { + System.out.println("---"); + System.out.println("标题: " + item.getTitle()); + System.out.println("URL: " + item.getUri()); + System.out.println("评分: " + item.getScore()); + String snippet = item.getSnippet(); + if (snippet != null && snippet.length() > 100) { + snippet = snippet.substring(0, 100) + "..."; + } + System.out.println("摘要: " + snippet); + } + + // 验证第一个结果的完整性 + SearchResultItem first = results.get(0); + assertNotNull(first.getId()); + assertEquals(SearchSourceType.WEB, first.getSourceType()); + assertNotNull(first.getTitle()); + assertFalse(first.getTitle().isEmpty()); + } +}