Skip to content

Commit 6bb2c1b

Browse files
iceAndFireisFailed and yuhongxiao authored
feat: add stream response in audio (#37)
Co-authored-by: yuhongxiao <[email protected]>
1 parent 0cde756 commit 6bb2c1b

File tree

5 files changed

+54
-7
lines changed

5 files changed

+54
-7
lines changed

src/zai/api_resource/audio/audio.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import TYPE_CHECKING, Mapping, Optional, cast
44

55
import httpx
6-
from httpx import stream
76

87
from zai.core import (
98
NOT_GIVEN,
@@ -23,6 +22,8 @@
2322
from zai.types.sensitive_word_check import SensitiveWordCheckRequest
2423

2524
from .transcriptions import Transcriptions
25+
from zai.core._streaming import StreamResponse
26+
from zai.types.audio import AudioSpeechChunk
2627

2728
if TYPE_CHECKING:
2829
from zai._client import ZaiClient
@@ -60,7 +61,7 @@ def speech(
6061
speed: float | None = 1.0,
6162
volume: float | None = 1.0,
6263
stream: bool | None = False
63-
) -> HttpxBinaryResponseContent:
64+
) -> HttpxBinaryResponseContent | StreamResponse[AudioSpeechChunk]:
6465
"""
6566
Generate speech audio from text input
6667
@@ -83,7 +84,6 @@ def speech(
8384
'voice': voice,
8485
'response_format': response_format,
8586
'encode_format': encode_format,
86-
'sensitive_word_check': sensitive_word_check,
8787
'request_id': request_id,
8888
'user_id': user_id,
8989
'speed': speed,
@@ -96,6 +96,8 @@ def speech(
9696
body=maybe_transform(body, AudioSpeechParams),
9797
options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
9898
cast_type=HttpxBinaryResponseContent,
99+
stream=stream or False,
100+
stream_cls=StreamResponse[AudioSpeechChunk]
99101
)
100102

101103
def customization(

src/zai/types/audio/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .audio_customization_param import AudioCustomizationParam
2+
from .audio_speech_chunk import AudioSpeechChunk
23
from .audio_speech_params import AudioSpeechParams
34
from .transcriptions_create_param import TranscriptionsParam
45

5-
__all__ = ['AudioSpeechParams', 'AudioCustomizationParam', 'TranscriptionsParam']
6+
__all__ = ['AudioSpeechParams', 'AudioCustomizationParam', 'TranscriptionsParam', 'AudioSpeechChunk']
src/zai/types/audio/audio_speech_chunk.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import List, Optional, Dict, Any
2+
3+
from ...core import BaseModel
4+
5+
__all__ = [
6+
"AudioSpeechChunk",
7+
"AudioError",
8+
"AudioSpeechChoice",
9+
"AudioSpeechDelta"
10+
]
11+
12+
13+
class AudioSpeechDelta(BaseModel):
14+
content: Optional[str] = None
15+
role: Optional[str] = None
16+
17+
18+
class AudioSpeechChoice(BaseModel):
19+
delta: AudioSpeechDelta
20+
finish_reason: Optional[str] = None
21+
index: int
22+
23+
class AudioError(BaseModel):
24+
code: Optional[str] = None
25+
message: Optional[str] = None
26+
27+
28+
class AudioSpeechChunk(BaseModel):
29+
choices: List[AudioSpeechChoice]
30+
request_id: Optional[str] = None
31+
created: Optional[int] = None
32+
error: Optional[AudioError] = None

src/zai/types/audio/audio_speech_params.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,7 @@ class AudioSpeechParams(TypedDict, total=False):
2929
sensitive_word_check: Optional[SensitiveWordCheckRequest]
3030
request_id: str
3131
user_id: str
32+
encode_format: str
33+
speed: float
34+
volume: float
35+
stream: bool

tests/integration_tests/test_audio.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import base64
12
import logging
23
import logging.config
34
from pathlib import Path
45

56
import zai
67
from zai import ZaiClient
78

8-
99
def test_audio_speech(logging_conf):
1010
logging.config.dictConfig(logging_conf) # type: ignore
1111
client = ZaiClient() # Fill in your own API Key
@@ -17,11 +17,19 @@ def test_audio_speech(logging_conf):
1717
voice='female',
1818
response_format='pcm',
1919
encode_format='base64',
20-
stream=False,
20+
stream=True,
2121
speed=1.0,
2222
volume=1.0,
2323
)
24-
response.stream_to_file(speech_file_path)
24+
with open("output.pcm", "wb") as f:
25+
for item in response:
26+
choice = item.choices[0]
27+
index = choice.index
28+
finish_reason = choice.finish_reason
29+
if choice.delta is None:
30+
break
31+
audio_delta = choice.delta.content
32+
f.write(base64.b64decode(audio_delta))
2533

2634
except zai.core._errors.APIRequestFailedError as err:
2735
print(err)

0 commit comments

Comments
 (0)