Skip to content

Commit 47e30ba

Browse files
committed
feat(demohouse/shopping): mock vdb
1 parent e493786 commit 47e30ba

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

demohouse/shopping/backend/code/main.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
import json
22
import logging
33
import os
4+
import time
45
from typing import AsyncIterable
56
from arkitect.launcher.local.serve import launch_serve
67
from arkitect.telemetry.trace import task
7-
from arkitect.types.llm.model import ArkChatRequest, ArkChatParameters
8+
from arkitect.types.llm.model import (
9+
ArkChatRequest,
10+
ArkChatParameters,
11+
ArkChatCompletionChunk,
12+
BotUsage,
13+
ActionDetail,
14+
ToolDetail,
15+
)
816
from volcenginesdkarkruntime.types.chat import ChatCompletionChunk
917

1018
from arkitect.core.component.context.context import Context
1119

1220
from arkitect.core.runtime import Response
13-
from volcenginesdkarkruntime.types.chat.chat_completion_chunk import Choice, ChoiceDelta
1421

1522
from arkitect.core.component.context.model import State
1623

@@ -19,13 +26,37 @@
1926

2027
logger = logging.getLogger(__name__)
2128

22-
DOUBAO_VLM_ENDPOINT = "doubao-1-5-vision-pro-32k-250115"
29+
DOUBAO_VLM_ENDPOINT = "doubao-1-5-pro-32k-250115"
30+
31+
32+
@task()
33+
def get_vector_search_result_chunk(result: str) -> ArkChatCompletionChunk:
34+
return ArkChatCompletionChunk(
35+
id="",
36+
created=int(time.time()),
37+
model="",
38+
object="chat.completion.chunk",
39+
choices=[],
40+
bot_usage=BotUsage(
41+
action_details=[
42+
ActionDetail(
43+
name="vector_search",
44+
count=1,
45+
tool_details=[
46+
ToolDetail(
47+
name="vector_search", input=None, output=json.loads(result)
48+
)
49+
],
50+
)
51+
]
52+
),
53+
)
2354

2455

2556
@task()
2657
async def default_model_calling(
2758
request: ArkChatRequest,
28-
) -> AsyncIterable[ChatCompletionChunk]:
59+
) -> AsyncIterable[ArkChatCompletionChunk | ChatCompletionChunk]:
2960
parameters = ArkChatParameters(**request.__dict__)
3061
image_urls = [
3162
content.get("image_url", {}).get("url", "")
@@ -36,6 +67,12 @@ async def default_model_calling(
3667
]
3768
image_url = image_urls[-1] if len(image_urls) > 0 else ""
3869

70+
# only search for relevant product
71+
if request.metadata and request.metadata.get("search"):
72+
result = await vector_search("", image_url)
73+
yield get_vector_search_result_chunk(result)
74+
return
75+
3976
async def modify_url_hook(
4077
state: State, param: ChatCompletionMessageToolCallParam
4178
) -> ChatCompletionMessageToolCallParam:
@@ -55,12 +92,8 @@ async def modify_url_hook(
5592
async for chunk in stream:
5693
if tool_call and chunk.choices:
5794
tool_result = ctx.get_latest_message()
58-
chunk.choices.append(
59-
Choice(
60-
role="tool",
61-
delta=ChoiceDelta(content=tool_result.get("content")),
62-
index=len(chunk.choices),
63-
)
95+
yield get_vector_search_result_chunk(
96+
str(tool_result.get("content", ""))
6497
)
6598
tool_call = False
6699
yield chunk

demohouse/shopping/backend/code/vdb.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ async def vector_search(text: str, image_url: str) -> str:
4848
image_url: 固定填写为<image_url>
4949
"""
5050
client = AsyncArk(timeout=Timeout(connect=1.0, timeout=60.0))
51-
embedding_input = [MultimodalEmbeddingContentPartTextParam(type="text", text=text)]
51+
embedding_input = []
52+
if text != "":
53+
embedding_input = [MultimodalEmbeddingContentPartTextParam(type="text", text=text)]
5254
if image_url != "":
5355
embedding_input.append(
5456
MultimodalEmbeddingContentPartImageParam(
@@ -74,7 +76,7 @@ async def vector_search(text: str, image_url: str) -> str:
7476
"子类别": item.get("sub_category", ""),
7577
"价格": item.get("price", "99"),
7678
"销量": item.get("sales", "999"),
77-
"商品链接": tos_client.pre_signed_url(
79+
"图片链接": tos_client.pre_signed_url(
7880
http_method=HttpMethodType.Http_Method_Get,
7981
bucket="shopping",
8082
key=item.get("key", ""),

0 commit comments

Comments
 (0)