diff --git a/api/.ruff.toml b/api/.ruff.toml
index f30275a943d806..89a2da35d6cf87 100644
--- a/api/.ruff.toml
+++ b/api/.ruff.toml
@@ -85,11 +85,11 @@ ignore = [
 ]
 "tests/*" = [
     "F811", # redefined-while-unused
-    "F401", # unused-import
 ]

 [lint.pyflakes]
-extend-generics = [
+allowed-unused-imports = [
     "_pytest.monkeypatch",
     "tests.integration_tests",
+    "tests.unit_tests",
 ]
diff --git a/api/Dockerfile b/api/Dockerfile
index b5b8f69829d918..df676f19261133 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -55,7 +55,7 @@ RUN apt-get update \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
+    && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.19+dfsg-1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
     # install a chinese font to support the use of tools like matplotlib
     && apt-get install -y fonts-noto-cjk \
     && apt-get autoremove -y \
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 26a3a022d401a4..6942ac6fbe62dd 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -2,7 +2,7 @@
 import logging

 from flask import abort, request
-from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from flask_restful import Resource, inputs, marshal_with, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

 import services
@@ -14,7 +14,7 @@
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from factories import variable_factory
-from fields.workflow_fields import workflow_fields
+from fields.workflow_fields import workflow_fields, workflow_pagination_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
 from libs.helper import TimestampField, uuid_value
@@ -440,6 +440,31 @@ def get(self, app_model: App):
         }


+class PublishedAllWorkflowApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+    @marshal_with(workflow_pagination_fields)
+    def get(self, app_model: App):
+        """
+        Get published workflows
+        """
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
+        parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+        args = parser.parse_args()
+        page = args.get("page")
+        limit = args.get("limit")
+        workflow_service = WorkflowService()
+        workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit)
+
+        return {"items": workflows, "page": page, "limit": limit, "has_more": has_more}
+
+
 api.add_resource(DraftWorkflowApi, "/apps//workflows/draft")
 api.add_resource(WorkflowConfigApi, "/apps//workflows/draft/config")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps//advanced-chat/workflows/draft/run")
@@ -454,6 +479,7 @@ def get(self, app_model: App):
     WorkflowDraftRunIterationNodeApi, "/apps//workflows/draft/iteration/nodes//run"
 )
 api.add_resource(PublishedWorkflowApi, "/apps//workflows/publish")
+api.add_resource(PublishedAllWorkflowApi, "/apps//workflows")
 api.add_resource(DefaultBlockConfigsApi, "/apps//workflows/default-workflow-block-configs")
 api.add_resource(
     DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs/"
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index 690297048eb55c..405d5ed607f639 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -66,10 +66,17 @@ def post(self, installed_app, message_id):
         parser = reqparse.RequestParser()
         parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
+        parser.add_argument("content", type=str, location="json")
         args = parser.parse_args()

         try:
-            MessageService.create_feedback(app_model, message_id, current_user, args.get("rating"), args.get("content"))
+            MessageService.create_feedback(
+                app_model=app_model,
+                message_id=message_id,
+                user=current_user,
+                rating=args.get("rating"),
+                content=args.get("content"),
+            )
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 1afb41ea87660c..a2b41c1d38f87d 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -122,7 +122,7 @@ def put(self, member_id):
             return {"code": "invalid-role", "message": "Invalid role"}, 400

         member = db.session.get(Account, str(member_id))
-        if member:
+        if not member:
             abort(404)

         try:
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index bed89a99a58683..773ea0e0c69385 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -108,7 +108,13 @@ def post(self, app_model: App, end_user: EndUser, message_id):
         args = parser.parse_args()

         try:
-            MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content"))
+            MessageService.create_feedback(
+                app_model=app_model,
+                message_id=message_id,
+                user=end_user,
+                rating=args.get("rating"),
+                content=args.get("content"),
+            )
         except services.errors.message.MessageNotExistsError:
             raise NotFound("Message Not Exists.")
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 84c58c62df5b3c..ea664b8f1be4d4 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -8,12 +8,16 @@
 import services.dataset_service
 from controllers.common.errors import FilenameNotExistsError
 from controllers.service_api import api
-from controllers.service_api.app.error import ProviderNotInitializeError
+from controllers.service_api.app.error import (
+    FileTooLargeError,
+    NoFileUploadedError,
+    ProviderNotInitializeError,
+    TooManyFilesError,
+    UnsupportedFileTypeError,
+)
 from controllers.service_api.dataset.error import (
     ArchivedDocumentImmutableError,
     DocumentIndexingError,
-    NoFileUploadedError,
-    TooManyFilesError,
 )
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from core.errors.error import ProviderTokenNotInitError
@@ -238,13 +242,18 @@ def post(self, tenant_id, dataset_id, document_id):
         if not file.filename:
             raise FilenameNotExistsError

-        upload_file = FileService.upload_file(
-            filename=file.filename,
-            content=file.read(),
-            mimetype=file.mimetype,
-            user=current_user,
-            source="datasets",
-        )
+        try:
+            upload_file = FileService.upload_file(
+                filename=file.filename,
+                content=file.read(),
+                mimetype=file.mimetype,
+                user=current_user,
+                source="datasets",
+            )
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
         args["data_source"] = data_source
         # validate args
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 8d69bdcec2c2ac..ae086ba8ed66ec 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -339,13 +339,13 @@ def save_agent_thought(
             raise ValueError(f"Agent thought {agent_thought.id} not found")
         agent_thought = queried_thought

-        if thought is not None:
+        if thought:
             agent_thought.thought = thought

-        if tool_name is not None:
+        if tool_name:
             agent_thought.tool = tool_name

-        if tool_input is not None:
+        if tool_input:
             if isinstance(tool_input, dict):
                 try:
                     tool_input = json.dumps(tool_input, ensure_ascii=False)
@@ -354,7 +354,7 @@ def save_agent_thought(

             agent_thought.tool_input = tool_input

-        if observation is not None:
+        if observation:
             if isinstance(observation, dict):
                 try:
                     observation = json.dumps(observation, ensure_ascii=False)
@@ -363,7 +363,7 @@ def save_agent_thought(

             agent_thought.observation = observation

-        if answer is not None:
+        if answer:
             agent_thought.answer = answer

         if messages_ids is not None and len(messages_ids) > 0:
diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py
index c77a499982e98b..1247a11fe8a400 100644
--- a/api/core/model_runtime/model_providers/wenxin/_common.py
+++ b/api/core/model_runtime/model_providers/wenxin/_common.py
@@ -122,6 +122,7 @@ class _CommonWenxin:
         "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh",
         "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k",
         "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base",
+        "ernie-lite-pro-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-pro-128k",
     }

     function_calling_supports = [
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-pro-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-pro-128k.yaml
new file mode 100644
index 00000000000000..4f5832c8598926
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-pro-128k.yaml
@@ -0,0 +1,42 @@
+model: ernie-lite-pro-128k
+label:
+  en_US: Ernie-Lite-Pro-128K
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    min: 0.1
+    max: 1.0
+    default: 0.8
+  - name: top_p
+    use_template: top_p
+  - name: min_output_tokens
+    label:
+      en_US: "Min Output Tokens"
+      zh_Hans: "最小输出Token数"
+    use_template: max_tokens
+    min: 2
+    max: 2048
+    help:
+      zh_Hans: 指定模型最小输出token数
+      en_US: Specifies the lower limit on the length of generated results.
+  - name: max_output_tokens
+    label:
+      en_US: "Max Output Tokens"
+      zh_Hans: "最大输出Token数"
+    use_template: max_tokens
+    min: 2
+    max: 2048
+    default: 2048
+    help:
+      zh_Hans: 指定模型最大输出token数
+      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index 8b17e8dc0a3762..a6214d955b1ddd 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,5 +1,5 @@
 import re
-from typing import Optional
+from typing import Optional, cast


 class JiebaKeywordTableHandler:
@@ -8,18 +8,20 @@ def __init__(self):

         from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

-        jieba.analyse.default_tfidf.stop_words = STOPWORDS
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS  # type: ignore

     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
-        import jieba  # type: ignore
+        import jieba.analyse  # type: ignore

         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
         )
+        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+        keywords = cast(list[str], keywords)

-        return set(self._expand_tokens_with_subtokens(keywords))
+        return set(self._expand_tokens_with_subtokens(set(keywords)))

     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index fdc2e46d141d07..41355d3fac1234 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -138,17 +138,24 @@ def _get_notion_block_data(self, page_id: str) -> list[str]:
         block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
         while True:
             query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
-            res = requests.request(
-                "GET",
-                block_url,
-                headers={
-                    "Authorization": "Bearer " + self._notion_access_token,
-                    "Content-Type": "application/json",
-                    "Notion-Version": "2022-06-28",
-                },
-                params=query_dict,
-            )
-            data = res.json()
+            try:
+                res = requests.request(
+                    "GET",
+                    block_url,
+                    headers={
+                        "Authorization": "Bearer " + self._notion_access_token,
+                        "Content-Type": "application/json",
+                        "Notion-Version": "2022-06-28",
+                    },
+                    params=query_dict,
+                )
+                if res.status_code != 200:
+                    raise ValueError(f"Error fetching Notion block data: {res.text}")
+                data = res.json()
+            except requests.RequestException as e:
+                raise ValueError("Error fetching Notion block data") from e
+            if "results" not in data or not isinstance(data["results"], list):
+                raise ValueError("Error fetching Notion block data")
             for result in data["results"]:
                 result_type = result["type"]
                 result_obj = result[result_type]
diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
index 050b468b740c27..aca369b4389787 100644
--- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
+++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
@@ -21,7 +21,7 @@ def _bedrock_retrieve(

             retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

-            # 如果有元数据过滤条件,则添加到检索配置中
+            # Add metadata filter to retrieval configuration if present
             if metadata_filter:
                 retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

@@ -77,7 +77,7 @@ def _invoke(
             if not query:
                 return self.create_text_message("Please input query")

-            # 获取元数据过滤条件(如果存在)
+            # Get metadata filter conditions (if they exist)
             metadata_filter_str = tool_parameters.get("metadata_filter")
             metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

@@ -86,7 +86,7 @@ def _invoke(
                 query_input=query,
                 knowledge_base_id=self.knowledge_base_id,
                 num_results=self.topk,
-                metadata_filter=metadata_filter,  # 将元数据过滤条件传递给检索方法
+                metadata_filter=metadata_filter,
             )
             line = 5

@@ -109,7 +109,7 @@ def validate_parameters(self, parameters: dict[str, Any]) -> None:
         if not parameters.get("query"):
             raise ValueError("query is required")

-        # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
+        # Optional: Validate if metadata filter is a valid JSON string (if provided)
         metadata_filter_str = parameters.get("metadata_filter")
         if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
             raise ValueError("metadata_filter must be a valid JSON object")
diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
index 9e51d52def4037..31961a0cf03024 100644
--- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
+++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
@@ -73,9 +73,9 @@ parameters:
       llm_description: AWS region where the Bedrock Knowledge Base is located
     form: form

-  - name: metadata_filter
-    type: string
-    required: false
+  - name: metadata_filter # Additional parameter for metadata filtering
+    type: string # String type, expects JSON-formatted filter conditions
+    required: false # Optional field - can be omitted
     label:
       en_US: Metadata Filter
       zh_Hans: 元数据过滤器
diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py
index e05e2d9bf7d356..3d88f28dbd2fc7 100644
--- a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py
+++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py
@@ -6,8 +6,8 @@
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool.builtin_tool import BuiltinTool

-# 定义标签映射
-LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
+# Define label mappings
+LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"}


 class ContentModerationTool(BuiltinTool):
@@ -28,12 +28,12 @@ def _invoke_sagemaker(self, payload: dict, endpoint: str):
         # Handle nested JSON if present
         if isinstance(json_obj, dict) and "body" in json_obj:
             body_content = json.loads(json_obj["body"])
-            raw_label = body_content.get("label")
+            prediction_result = body_content.get("prediction")
         else:
-            raw_label = json_obj.get("label")
+            prediction_result = json_obj.get("prediction")

-        # 映射标签并返回
-        result = LABEL_MAPPING.get(raw_label, "NO_SAFE")  # 如果映射中没有找到,默认返回NO_SAFE
+        # Map labels and return
+        result = LABEL_MAPPING.get(prediction_result, "NO_SAFE")  # If not found in mapping, default to NO_SAFE
         return result

     def _invoke(
diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
index 715b1ddeddcae5..8320bd84efa440 100644
--- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
+++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py
@@ -10,8 +10,7 @@ class SageMakerReRankTool(BuiltinTool):
     sagemaker_client: Any = None
-    sagemaker_endpoint: str | None = None
-    topk: int | None = None
+    sagemaker_endpoint: str = None

     def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
         inputs = [query_input] * len(docs)
@@ -47,8 +46,7 @@ def _invoke(
             self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
             line = 2

-            if not self.topk:
-                self.topk = tool_parameters.get("topk", 5)
+            topk = tool_parameters.get("topk", 5)
             line = 3

             query = tool_parameters.get("query", "")
@@ -75,7 +73,7 @@ def _invoke(
             sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
             line = 9

-            return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
+            return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]]
         except Exception as e:
             return self.create_text_message(f"Exception {str(e)}, line : {line}")
diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py
index f994cdbf66e78b..2bf10ce8ff2632 100644
--- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py
+++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py
@@ -125,7 +125,7 @@ def generate_image_by_prompt(self, prompt: dict) -> list[bytes]:
             for output in history["outputs"].values():
                 for img in output.get("images", []):
                     image_data = self.get_image(img["filename"], img["subfolder"], img["type"])
-                    images.append(image_data)
+                    images.append((image_data, img["filename"]))
             return images
         finally:
             ws.close()
diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py
index 87837362779baa..eb085f221ebdda 100644
--- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py
+++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py
@@ -1,4 +1,5 @@
 import json
+import mimetypes
 from typing import Any

 from core.file import FileType
@@ -75,10 +76,12 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
         images = comfyui.generate_image_by_prompt(prompt)
         result = []
-        for img in images:
+        for image_data, filename in images:
             result.append(
                 self.create_blob_message(
-                    blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
+                    blob=image_data,
+                    meta={"mime_type": mimetypes.guess_type(filename)[0]},
+                    save_as=self.VariableKey.IMAGE.value,
                 )
             )
         return result
diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py
index edff4a2d07cca2..eea66ee4ed7447 100644
--- a/api/core/tools/tool/workflow_tool.py
+++ b/api/core/tools/tool/workflow_tool.py
@@ -1,12 +1,13 @@
 import json
 import logging
 from copy import deepcopy
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, cast

 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
 from core.tools.tool.tool import Tool
 from extensions.ext_database import db
+from factories.file_factory import build_from_mapping
 from models.account import Account
 from models.model import App, EndUser
 from models.workflow import Workflow
@@ -194,10 +195,18 @@ def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
             if isinstance(value, list):
                 for item in value:
                     if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
-                        file = File.model_validate(item)
+                        item["tool_file_id"] = item.get("related_id")
+                        file = build_from_mapping(
+                            mapping=item,
+                            tenant_id=str(cast(Tool.Runtime, self.runtime).tenant_id),
+                        )
                         files.append(file)
             elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
-                file = File.model_validate(value)
+                value["tool_file_id"] = value.get("related_id")
+                file = build_from_mapping(
+                    mapping=value,
+                    tenant_id=str(cast(Tool.Runtime, self.runtime).tenant_id),
+                )
                 files.append(file)

             result[key] = value
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index b3bcc3b2ccc309..5c672c985b6a1f 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -613,10 +613,10 @@ def _fetch_all_node_ids_in_parallels(
         for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
             # check which node is after
             if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
-                if node_id in merge_branch_node_ids:
+                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                     del merge_branch_node_ids[node_id2]
             elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
-                if node_id2 in merge_branch_node_ids:
+                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                     del merge_branch_node_ids[node_id]

         branches_merge_node_ids: dict[str, str] = {}
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
index 408ed31096d2a0..14396e9920a2c2 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
@@ -15,11 +15,11 @@ def handle(sender, **kwargs):
     app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()

-    removed_dataset_ids: set[int] = set()
+    removed_dataset_ids: set[str] = set()
     if not app_dataset_joins:
         added_dataset_ids = dataset_ids
     else:
-        old_dataset_ids: set[int] = set()
+        old_dataset_ids: set[str] = set()
         old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

         added_dataset_ids = dataset_ids - old_dataset_ids
@@ -39,8 +39,8 @@ def handle(sender, **kwargs):
     db.session.commit()


-def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[int]:
-    dataset_ids: set[int] = set()
+def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]:
+    dataset_ids: set[str] = set()
     if not app_model_config:
         return dataset_ids
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
index 7a31c82f6adbc2..dd2efed94bca7f 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
@@ -17,11 +17,11 @@ def handle(sender, **kwargs):
     dataset_ids = get_dataset_ids_from_workflow(published_workflow)
     app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()

-    removed_dataset_ids: set[int] = set()
+    removed_dataset_ids: set[str] = set()
     if not app_dataset_joins:
         added_dataset_ids = dataset_ids
     else:
-        old_dataset_ids: set[int] = set()
+        old_dataset_ids: set[str] = set()
         old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

         added_dataset_ids = dataset_ids - old_dataset_ids
@@ -41,8 +41,8 @@ def handle(sender, **kwargs):
     db.session.commit()


-def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[int]:
-    dataset_ids: set[int] = set()
+def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
+    dataset_ids: set[str] = set()
     graph = published_workflow.graph_dict
     if not graph:
         return dataset_ids
@@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[int]:
     for node in knowledge_retrieval_nodes:
         try:
             node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
-            dataset_ids.update(int(dataset_id) for dataset_id in node_data.dataset_ids)
+            dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
         except Exception as e:
             continue
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 30f216ff95612b..26bd6b357712c9 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -69,6 +69,7 @@ def __call__(self, *args: object, **kwargs: object) -> object:
         "schedule.create_tidb_serverless_task",
         "schedule.update_tidb_serverless_status_task",
         "schedule.clean_messages",
+        "schedule.mail_clean_document_notify_task",
     ]
     day = dify_config.CELERY_BEAT_SCHEDULER_TIME
     beat_schedule = {
@@ -92,6 +93,11 @@ def __call__(self, *args: object, **kwargs: object) -> object:
             "task": "schedule.clean_messages.clean_messages",
             "schedule": timedelta(days=day),
         },
+        # every Monday
+        "mail_clean_document_notify_task": {
+            "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
+            "schedule": crontab(minute="0", hour="10", day_of_week="1"),
+        },
     }
     celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index bd093d4063bc2e..32f979a5f2aa08 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -45,6 +45,7 @@ def format(self, value):
     "graph": fields.Raw(attribute="graph_dict"),
     "features": fields.Raw(attribute="features_dict"),
     "hash": fields.String(attribute="unique_hash"),
+    "version": fields.String(attribute="version"),
     "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
     "created_at": TimestampField,
     "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
@@ -61,3 +62,10 @@ def format(self, value):
     "updated_by": fields.String,
     "updated_at": TimestampField,
 }
+
+workflow_pagination_fields = {
+    "items": fields.List(fields.Nested(workflow_fields), attribute="items"),
+    "page": fields.Integer,
+    "limit": fields.Integer(attribute="limit"),
+    "has_more": fields.Boolean(attribute="has_more"),
+}
diff --git a/api/migrations/README b/api/migrations/README
index 220678df7ab06e..0e048441597444 100644
--- a/api/migrations/README
+++ b/api/migrations/README
@@ -1,2 +1 @@
 Single-database configuration for Flask.
-
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 32a0860b77bbea..8642df8adb55c5 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -414,6 +414,18 @@ class WorkflowRun(db.Model):  # type: ignore[name-defined]
     finished_at = db.Column(db.DateTime)
     exceptions_count = db.Column(db.Integer, server_default=db.text("0"))

+    @property
+    def created_by_account(self):
+        created_by_role = CreatedByRole(self.created_by_role)
+        return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
+
+    @property
+    def created_by_end_user(self):
+        from models.model import EndUser
+
+        created_by_role = CreatedByRole(self.created_by_role)
+        return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
+
     @property
     def graph_dict(self):
         return json.loads(self.graph) if self.graph else {}
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index 48bdc872f41e5c..5e4d3ec323e41d 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -28,7 +28,6 @@ def clean_messages():
     plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
         days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
     )
-    page = 1
     while True:
         try:
             # Main query with join and filter
@@ -79,4 +78,4 @@ def clean_messages():
                 db.session.query(Message).filter(Message.id == message.id).delete()
                 db.session.commit()
     end_at = time.perf_counter()
-    click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green"))
+    click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green"))
diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py
index 766954a257371f..fe6839288d8503 100644
--- a/api/schedule/mail_clean_document_notify_task.py
+++ b/api/schedule/mail_clean_document_notify_task.py
@@ -3,14 +3,18 @@
 from collections import defaultdict

 import click
-from celery import shared_task  # type: ignore
+from flask import render_template  # type: ignore

+import app
+from configs import dify_config
+from extensions.ext_database import db
 from extensions.ext_mail import mail
 from models.account import Account, Tenant, TenantAccountJoin
 from models.dataset import Dataset, DatasetAutoDisableLog
+from services.feature_service import FeatureService


-@shared_task(queue="mail")
+@app.celery.task(queue="dataset")
 def send_document_clean_notify_task():
     """
     Async Send document clean notify mail
@@ -29,35 +33,58 @@ def send_document_clean_notify_task():
         # group by tenant_id
         dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
         for dataset_auto_disable_log in dataset_auto_disable_logs:
+            if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
+                dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
             dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
-
+        url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
         for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
-            knowledge_details = []
-            tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
-            if not tenant:
-                continue
-            current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
-            if not current_owner_join:
-                continue
-            account = Account.query.filter(Account.id == current_owner_join.account_id).first()
-            if not account:
-                continue
-
-            dataset_auto_dataset_map = {}  # type: ignore
-            for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
-                dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
-                    dataset_auto_disable_log.document_id
-                )
+            features = FeatureService.get_features(tenant_id)
+            plan = features.billing.subscription.plan
+            if plan != "sandbox":
+                knowledge_details = []
+                # check tenant
+                tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
+                if not tenant:
+                    continue
+                # check current owner
+                current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+                if not current_owner_join:
+                    continue
+                account = Account.query.filter(Account.id == current_owner_join.account_id).first()
+                if not account:
+                    continue

-            for dataset_id, document_ids in dataset_auto_dataset_map.items():
-                dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
-                if dataset:
-                    document_count = len(document_ids)
-                    knowledge_details.append(f"  • Knowledge base {dataset.name}: {document_count} documents  • ")
+                dataset_auto_dataset_map = {}  # type: ignore
+                for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
+                    if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
+                        dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
+                    dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
+                        dataset_auto_disable_log.document_id
+                    )
+                for dataset_id, document_ids in dataset_auto_dataset_map.items():
+                    dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
+                    if dataset:
+                        document_count = len(document_ids)
+                        knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
+                if knowledge_details:
+                    html_content = render_template(
+                        "clean_document_job_mail_template-US.html",
+                        userName=account.email,
+                        knowledge_details=knowledge_details,
+                        url=url,
+                    )
+                    mail.send(
+                        to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content
+                    )
+
+                # update notified to True
+                for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
+                    dataset_auto_disable_log.notified = True
+                db.session.commit()

         end_at = time.perf_counter()
         logging.info(
             click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
         )
     except Exception:
-        logging.exception("Send invite member mail to failed")
+        logging.exception("Send document clean notify mail failed")
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 528a0dbcd39d9e..b6d6d05e589e92 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -176,6 +176,9 @@ def import_app(
             data["kind"] = "app"

         imported_version = data.get("version", "0.1.0")
+        # check if imported_version is a float-like string
+        if not isinstance(imported_version, str):
+            raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
         status = _check_version_compatibility(imported_version)

         # Extract app data
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index ef52301c0aed5d..f4178a69a4aada 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -139,7 +139,7 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None):
                 return Response(stream_with_context(response), content_type="audio/mpeg")
             return response
         else:
-            if not text:
+            if text is None:
                 raise ValueError("Text is required")
             response = invoke_tts(text, app_model, voice)
             if isinstance(response, Generator):
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 1fd18568f54126..4821eb66969639 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -86,25 +86,30 @@ def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids
                 else:
                     return [], 0
             else:
-                # show all datasets that the user has permission to access
-                if permitted_dataset_ids:
-                    query = query.filter(
-                        db.or_(
-                            Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
-                            db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id),
-                            db.and_(
-                                Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
-                                Dataset.id.in_(permitted_dataset_ids),
-                            ),
+                if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
+                    # show all datasets that the user has permission to access
+                    if permitted_dataset_ids:
+                        query = query.filter(
+                            db.or_(
+                                Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
+                                db.and_(
+                                    Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
+                                ),
+                                db.and_(
+                                    Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
+                                    Dataset.id.in_(permitted_dataset_ids),
+                                ),
+                            )
                         )
-                    )
-                else:
-                    query = query.filter(
-                        db.or_(
-                            Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
-                            db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id),
+                    else:
+                        query = query.filter(
+                            db.or_(
+                                Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
+                                db.and_(
+                                    Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
+                                ),
+                            )
                         )
-                    )
         else:
             # if no user, only show datasets that are shared with all team members
             query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
@@ -377,14 +382,19 @@ def check_dataset_permission(dataset, user):
         if dataset.tenant_id != user.current_tenant_id:
             logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
             raise NoPermissionError("You do not have permission to access this dataset.")
-        if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
-            logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
-            raise NoPermissionError("You do not have permission to access this dataset.")
-        if dataset.permission == "partial_members":
-            user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
-            if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id:
+        if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
+            if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
                 logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
                 raise NoPermissionError("You do not have permission to access this dataset.")
+            if dataset.permission == "partial_members":
+                user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
+                if (
+                    not user_permission
+                    and dataset.tenant_id != user.current_tenant_id
+                    and dataset.created_by != user.id
+                ):
+                    logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
+                    raise NoPermissionError("You do not have permission to access this dataset.")

     @staticmethod
     def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
@@ -394,15 +404,16 @@ def check_dataset_operator_permission(user: Optional[Account] = None, dataset: O
         if not user:
             raise ValueError("User not found")

-        if dataset.permission == DatasetPermissionEnum.ONLY_ME:
-            if dataset.created_by != user.id:
-                raise NoPermissionError("You do not have permission to access this dataset.")
+        if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
+            if dataset.permission == DatasetPermissionEnum.ONLY_ME:
+                if dataset.created_by != user.id:
+                    raise NoPermissionError("You do not have permission to access this dataset.")

-        elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
-            if not any(
-                dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
-            ):
-                raise NoPermissionError("You do not have permission to access this dataset.")
+            elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
+                if not any(
+                    dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
+                ):
+                    raise NoPermissionError("You do not have permission to access this dataset.")

     @staticmethod
     def get_dataset_queries(dataset_id: str, page: int, per_page: int):
@@ -441,7 +452,7 @@ def get_dataset_auto_disable_logs(dataset_id: str) -> dict:


 class DocumentService:
-    DEFAULT_RULES = {
+    DEFAULT_RULES: dict[str, Any] = {
         "mode": "custom",
         "rules": {
             "pre_processing_rules": [
@@ -455,7 +466,7 @@ class DocumentService:
         },
     }

-    DOCUMENT_METADATA_SCHEMA = {
+    DOCUMENT_METADATA_SCHEMA: dict[str, Any] = {
         "book": {
             "title": str,
             "language": str,
diff --git a/api/services/errors/base.py b/api/services/errors/base.py
index 4d39f956b8c932..35ea28468e0d86 100644
--- a/api/services/errors/base.py
+++ b/api/services/errors/base.py
@@ -1,6 +1,6 @@
 from typing import Optional


-class BaseServiceError(Exception):
+class BaseServiceError(ValueError):
     def __init__(self, description: Optional[str] = None):
         self.description = description
diff --git a/api/services/message_service.py b/api/services/message_service.py
index c4447a84da5e09..c17122ef647ecd 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -152,6 +152,7 @@ def pagination_by_last_id(
     @classmethod
     def create_feedback(
         cls,
+        *,
         app_model: App,
         message_id: str,
         user: Optional[Union[Account, EndUser]],
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 0e3bd3a7b83c68..988f9df927e4f7 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -425,7 +425,7 @@ def test_api_tool_preview(
                     "tenant_id": tenant_id,
                 }
             )
-            result = tool.validate_credentials(credentials, parameters)
+            result = runtime_tool.validate_credentials(credentials, parameters)
         except Exception as e:
             return {"error": str(e)}
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 81b197a2478992..9f7a9c770d9306 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -5,6 +5,8 @@
 from typing import Any, Optional, cast
 from uuid import uuid4

+from sqlalchemy import desc
+
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -76,6 +78,28 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]:

         return workflow

+    def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]:
+        """
+        Get published workflow with pagination
+        """
+        if not app_model.workflow_id:
+            return [], False
+
+        workflows = (
+            db.session.query(Workflow)
+            .filter(Workflow.app_id == app_model.id)
+            .order_by(desc(Workflow.version))
+            .offset((page - 1) * limit)
+            .limit(limit + 1)
+            .all()
+        )
+
+        has_more = len(workflows) > limit
+        if has_more:
+            workflows = workflows[:-1]
+
+        return workflows, has_more
+
     def sync_draft_workflow(
         self,
         *,
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index 9a172b2d9d8157..bd7fcdadeaa374 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -38,7 +38,11 @@ def add_document_to_index_task(dataset_document_id: str):
     try:
         segments = (
             db.session.query(DocumentSegment)
-            .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+            .filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.enabled == False,
+                DocumentSegment.status == "completed",
+            )
             .order_by(DocumentSegment.position.asc())
             .all()
         )
@@ -85,6 +89,16 @@ def add_document_to_index_task(dataset_document_id: str):
         db.session.query(DatasetAutoDisableLog).filter(
             DatasetAutoDisableLog.document_id == dataset_document.id
         ).delete()
+
+        # update segment to enable
+        db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).update(
+            {
+                DocumentSegment.enabled: True,
+                DocumentSegment.disabled_at: None,
+                DocumentSegment.disabled_by: None,
+                DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+            }
+        )
         db.session.commit()

         end_at = time.perf_counter()
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index 1d580b38028f37..d0c4382f58d75a 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -1,3 +1,4 @@
+import datetime
 import logging
 import time

@@ -46,6 +47,16 @@ def remove_document_from_index_task(document_id: str):
                 index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
             except Exception:
                 logging.exception(f"clean dataset {dataset.id} from index failed")
+        # update segment to disable
+        db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).update(
+            {
+                DocumentSegment.enabled: False,
+                DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                DocumentSegment.disabled_by: document.disabled_by,
+                DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+            }
+        )
+        db.session.commit()

         end_at = time.perf_counter()
         logging.info(
diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html
index b7c9538f9f8bee..88e78f41c78b46 100644
--- a/api/templates/clean_document_job_mail_template-US.html
+++ b/api/templates/clean_document_job_mail_template-US.html
@@ -45,14 +45,14 @@
         .content ul li {
             margin-bottom: 10px;
         }
-        .cta-button {
+        .cta-button, .cta-button:hover, .cta-button:active, .cta-button:visited, .cta-button:focus {
             display: block;
             margin: 20px auto;
             padding: 10px 20px;
             background-color: #4e89f9;
-            color: #ffffff;
+            color: #ffffff !important;
             text-align: center;
-            text-decoration: none;
+            text-decoration: none !important;
             border-radius: 5px;
             width: fit-content;
         }
@@ -69,7 +69,7 @@