Skip to content

Commit

Permalink
[Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
warren830 authored Dec 30, 2024
1 parent adacd01 commit 5624507
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _bedrock_retrieve(

retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

# 如果有元数据过滤条件,则添加到检索配置中
# Add metadata filter to retrieval configuration if present
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

Expand Down Expand Up @@ -77,7 +77,7 @@ def _invoke(
if not query:
return self.create_text_message("Please input query")

# 获取元数据过滤条件(如果存在)
# Get metadata filter conditions (if they exist)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

Expand All @@ -86,7 +86,7 @@ def _invoke(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
metadata_filter=metadata_filter,
)

line = 5
Expand All @@ -109,7 +109,7 @@ def validate_parameters(self, parameters: dict[str, Any]) -> None:
if not parameters.get("query"):
raise ValueError("query is required")

# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
# Optional: Validate if metadata filter is a valid JSON string (if provided)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ parameters:
llm_description: AWS region where the Bedrock Knowledge Base is located
form: form

- name: metadata_filter
type: string
required: false
- name: metadata_filter # 新增的元数据过滤参数
type: string # 可以是字符串类型,包含 JSON 格式的过滤条件
required: false # 可选参数
label:
en_US: Metadata Filter
zh_Hans: 元数据过滤器
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

# 定义标签映射
LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
# Define label mappings
LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"}


class ContentModerationTool(BuiltinTool):
Expand All @@ -28,12 +28,12 @@ def _invoke_sagemaker(self, payload: dict, endpoint: str):
# Handle nested JSON if present
if isinstance(json_obj, dict) and "body" in json_obj:
body_content = json.loads(json_obj["body"])
raw_label = body_content.get("label")
prediction_result = body_content.get("prediction")
else:
raw_label = json_obj.get("label")
prediction_result = json_obj.get("prediction")

# 映射标签并返回
result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE
# Map labels and return
result = LABEL_MAPPING.get(prediction_result, "NO_SAFE") # If not found in mapping, default to NO_SAFE
return result

def _invoke(
Expand Down

0 comments on commit 5624507

Please sign in to comment.