Commit 7b03a03

yihong0618 authored
fix: better memory usage from 800+ to 500+ (#11796)
Signed-off-by: yihong0618 <[email protected]>
1 parent 52201d9 · commit 7b03a03

File tree: 5 files changed, +56 -26 lines changed
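The pattern throughout this commit is the same in all five files: move heavy third-party imports (vertexai, google-cloud, jieba, nltk, unstructured) out of module scope and into the functions that actually use them, so a process that never touches a given provider never pays for its import. The title's "800+ to 500+" figures carry no unit in the commit message, presumably MB of process memory. As a rough way to see what a single heavy import costs, here is a minimal sketch, not part of the commit; psutil is an assumed extra dependency, and pandas merely stands in for any heavy module:

import os

import psutil  # assumed available; any RSS probe would do

proc = psutil.Process(os.getpid())
rss_before = proc.memory_info().rss

import pandas  # noqa: E402  # the heavy import under test

rss_after = proc.memory_info().rss
print(f"importing pandas added ~{(rss_after - rss_before) / 2**20:.1f} MiB RSS")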

api/core/model_runtime/model_providers/vertex_ai/llm/llm.py (+18 -8)
@@ -4,11 +4,10 @@
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import google.auth.transport.requests
 import requests
-import vertexai.generative_models as glm
 from anthropic import AnthropicVertex, Stream
 from anthropic.types import (
     ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@
     MessageStreamEvent,
 )
 from google.api_core import exceptions
-from google.cloud import aiplatform
-from google.oauth2 import service_account
 from PIL import Image
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
+if TYPE_CHECKING:
+    import vertexai.generative_models as glm
+
 logger = logging.getLogger(__name__)
 
 
@@ -102,6 +102,8 @@ def _generate_anthropic(
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
+        from google.oauth2 import service_account
+
         # use Anthropic official SDK references
         # - https://github.com/anthropics/anthropic-sdk-python
         service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
 
         return text.rstrip()
 
-    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
+    def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
         """
         Convert tool messages to glm tools
 
         :param tools: tool messages
         :return: glm tools
         """
+        import vertexai.generative_models as glm
+
         return glm.Tool(
             function_declarations=[
                 glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ def _generate(
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+        import vertexai.generative_models as glm
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+
         config_kwargs = model_parameters.copy()
         config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
 
@@ -522,7 +530,7 @@
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
     def _handle_generate_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> LLMResult:
         """
         Handle llm response
@@ -554,7 +562,7 @@
         return result
 
     def _handle_generate_stream_response(
-        self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
+        self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
     ) -> Generator:
         """
         Handle llm stream response
@@ -638,13 +646,15 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
 
         return message_text
 
-    def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
+    def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
        """
         Format a single message into glm.Content for Google API
 
         :param message: one PromptMessage
         :return: glm Content representation of message
         """
+        import vertexai.generative_models as glm
+
         if isinstance(message, UserPromptMessage):
             glm_content = glm.Content(role="user", parts=[])
 
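The llm.py diff above shows the full recipe in one file: the module-level "import vertexai.generative_models as glm" moves under "if TYPE_CHECKING:" (so static checkers still resolve glm), every annotation that mentions glm becomes a string so it is never evaluated at import time, and each method that really needs the SDK imports it locally. Condensed into a standalone sketch (hypothetical function, same idiom as the diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only; never executed, so vertexai stays unloaded.
    import vertexai.generative_models as glm


def build_tool(declarations: list) -> "glm.Tool":  # quoted: not evaluated at def time
    # Real import deferred to the first call; later calls hit the sys.modules cache.
    import vertexai.generative_models as glm

    return glm.Tool(function_declarations=declarations)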

api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py (+14 -4)
@@ -2,12 +2,9 @@
 import json
 import time
 from decimal import Decimal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import tiktoken
-from google.cloud import aiplatform
-from google.oauth2 import service_account
-from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
 
 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
@@ -24,6 +21,11 @@
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
 
+if TYPE_CHECKING:
+    from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+else:
+    VertexTextEmbeddingModel = None
+
 
 class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
     """
@@ -48,6 +50,10 @@ def _invoke(
         :param input_type: input type
         :return: embeddings result
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         service_account_key = credentials.get("vertex_service_account_key", "")
         project_id = credentials["vertex_project_id"]
         location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         :param credentials: model credentials
         :return:
         """
+        from google.cloud import aiplatform
+        from google.oauth2 import service_account
+        from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
+
         try:
             service_account_key = credentials.get("vertex_service_account_key", "")
             project_id = credentials["vertex_project_id"]
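Unlike llm.py, this file adds an "else: VertexTextEmbeddingModel = None" branch. That placeholder is presumably needed because the name is referenced somewhere the interpreter actually evaluates at runtime (an unquoted annotation, for instance), where a bare TYPE_CHECKING import would leave it unbound. A self-contained illustration of the trap, with a hypothetical heavy_pkg:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from heavy_pkg import HeavyModel  # hypothetical heavy dependency
else:
    HeavyModel = None  # placeholder so runtime references don't raise NameError


def embed(model: HeavyModel, text: str) -> list[float]:
    # The unquoted annotation above is evaluated when `def` runs at module
    # import; binding HeavyModel to None satisfies it without the real import.
    return model.embed_one(text)  # hypothetical method on the heavy model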

api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py (+9 -6)
@@ -1,18 +1,19 @@
 import re
 from typing import Optional
 
-import jieba
-from jieba.analyse import default_tfidf
-
-from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
-
 
 class JiebaKeywordTableHandler:
     def __init__(self):
-        default_tfidf.stop_words = STOPWORDS
+        import jieba.analyse
+
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
+        jieba.analyse.default_tfidf.stop_words = STOPWORDS
 
     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
+        import jieba
+
         keywords = jieba.analyse.extract_tags(
             sentence=text,
             topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10
 
     def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
+        from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
+
         results = set()
         for token in tokens:
             results.add(token)
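One subtlety in this file: extract_keywords only runs "import jieba", yet calls jieba.analyse.extract_tags. That works because __init__ has already executed "import jieba.analyse" (it necessarily runs before any instance method), and importing a submodule binds it as an attribute of its parent package. The standard library demonstrates the same mechanics:

import xml.dom  # importing a submodule also binds `dom` onto the `xml` package

import xml  # a later bare parent import therefore exposes the submodule

print(xml.dom.getDOMImplementation)  # reachable; no AttributeError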

api/core/rag/datasource/vdb/oracle/oraclevector.py (+4 -2)
@@ -6,10 +6,8 @@
 from typing import Any
 
 import jieba.posseg as pseg
-import nltk
 import numpy
 import oracledb
-from nltk.corpus import stopwords
 from pydantic import BaseModel, model_validator
 
 from configs import dify_config
@@ -202,6 +200,10 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # lazy import
+        import nltk
+        from nltk.corpus import stopwords
+
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
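Finding imports worth deferring, like nltk here, is mechanical: CPython's -X importtime flag logs the self and cumulative microseconds of every import to stderr. A small Python sketch driving it (json stands in for nltk so the example runs without extra packages):

import subprocess
import sys

# Run an import under CPython's built-in import profiler (-X importtime).
result = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import json"],
    capture_output=True,
    text=True,
)
# stderr lines look like: "import time: <self us> | <cumulative us> | <module>"
for line in result.stderr.splitlines()[-5:]:
    print(line)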

api/core/workflow/nodes/document_extractor/node.py (+11 -6)
@@ -8,12 +8,6 @@
 import pandas as pd
 import pypdfium2  # type: ignore
 import yaml  # type: ignore
-from unstructured.partition.api import partition_via_api
-from unstructured.partition.email import partition_email
-from unstructured.partition.epub import partition_epub
-from unstructured.partition.msg import partition_msg
-from unstructured.partition.ppt import partition_ppt
-from unstructured.partition.pptx import partition_pptx
 
 from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
@@ -256,6 +250,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
 
 
 def _extract_text_from_ppt(file_content: bytes) -> str:
+    from unstructured.partition.ppt import partition_ppt
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_ppt(file=file)
@@ -265,6 +261,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
 
 
 def _extract_text_from_pptx(file_content: bytes) -> str:
+    from unstructured.partition.api import partition_via_api
+    from unstructured.partition.pptx import partition_pptx
+
     try:
         if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
             with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +286,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
 
 
 def _extract_text_from_epub(file_content: bytes) -> str:
+    from unstructured.partition.epub import partition_epub
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_epub(file=file)
@@ -296,6 +297,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
 
 
 def _extract_text_from_eml(file_content: bytes) -> str:
+    from unstructured.partition.email import partition_email
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_email(file=file)
@@ -305,6 +308,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
 
 
 def _extract_text_from_msg(file_content: bytes) -> str:
+    from unstructured.partition.msg import partition_msg
+
     try:
         with io.BytesIO(file_content) as file:
             elements = partition_msg(file=file)
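A fair worry about this whole commit: extractor helpers like these run once per document, and each call now executes an import statement. In CPython that is cheap after the first call, since `import` consults the sys.modules cache before doing any real work; only the first call pays. A quick check of that claim:

import sys
import time


def timed_lazy_import() -> float:
    start = time.perf_counter()
    import json  # stand-in for a heavy module such as unstructured
    return time.perf_counter() - start


first = timed_lazy_import()   # real import work, unless json was already loaded
second = timed_lazy_import()  # just a sys.modules lookup plus a local binding
print(f"first={first:.6f}s second={second:.6f}s cached={'json' in sys.modules}")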
