From 562450751f9d63ef8cdacdb3f757db888abca4c9 Mon Sep 17 00:00:00 2001 From: Warren Chen Date: Mon, 30 Dec 2024 22:26:04 +0800 Subject: [PATCH] [Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve (#12227) --- .../provider/builtin/aws/tools/bedrock_retrieve.py | 8 ++++---- .../provider/builtin/aws/tools/bedrock_retrieve.yaml | 6 +++--- .../aws/tools/sagemaker_chinese_toxicity_detector.py | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py index 050b468b740c27..aca369b4389787 100644 --- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py @@ -21,7 +21,7 @@ def _bedrock_retrieve( retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}} - # 如果有元数据过滤条件,则添加到检索配置中 + # Add metadata filter to retrieval configuration if present if metadata_filter: retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter @@ -77,7 +77,7 @@ def _invoke( if not query: return self.create_text_message("Please input query") - # 获取元数据过滤条件(如果存在) + # Get metadata filter conditions (if they exist) metadata_filter_str = tool_parameters.get("metadata_filter") metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None @@ -86,7 +86,7 @@ def _invoke( query_input=query, knowledge_base_id=self.knowledge_base_id, num_results=self.topk, - metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法 + metadata_filter=metadata_filter, ) line = 5 @@ -109,7 +109,7 @@ def validate_parameters(self, parameters: dict[str, Any]) -> None: if not parameters.get("query"): raise ValueError("query is required") - # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供) + # Optional: Validate if metadata filter is a valid JSON string (if provided) metadata_filter_str = parameters.get("metadata_filter") if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict): raise ValueError("metadata_filter must be a valid JSON object") diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml index 9e51d52def4037..d4c5c5f64a748d 100644 --- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml @@ -73,9 +73,9 @@ parameters: llm_description: AWS region where the Bedrock Knowledge Base is located form: form - - name: metadata_filter - type: string - required: false + - name: metadata_filter # 新增的元数据过滤参数 + type: string # 可以是字符串类型,包含 JSON 格式的过滤条件 + required: false # 可选参数 label: en_US: Metadata Filter zh_Hans: 元数据过滤器 diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py index e05e2d9bf7d356..3d88f28dbd2fc7 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py @@ -6,8 +6,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -# 定义标签映射 -LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"} +# Define label mappings +LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"} class ContentModerationTool(BuiltinTool): @@ -28,12 +28,12 @@ def _invoke_sagemaker(self, payload: dict, endpoint: str): # Handle nested JSON if present if isinstance(json_obj, dict) and "body" in json_obj: body_content = json.loads(json_obj["body"]) - raw_label = body_content.get("label") + prediction_result = body_content.get("prediction") else: - raw_label = json_obj.get("label") + prediction_result = json_obj.get("prediction") - # 映射标签并返回 - result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE + # Map labels and return + result = LABEL_MAPPING.get(prediction_result, "NO_SAFE") # If not found in mapping, default to NO_SAFE return result def _invoke(