diff --git a/src/zai/api_resource/audio/audio.py b/src/zai/api_resource/audio/audio.py index 2de7b47..a3cd152 100644 --- a/src/zai/api_resource/audio/audio.py +++ b/src/zai/api_resource/audio/audio.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Mapping, Optional, cast import httpx -from httpx import stream from zai.core import ( NOT_GIVEN, @@ -23,6 +22,8 @@ from zai.types.sensitive_word_check import SensitiveWordCheckRequest from .transcriptions import Transcriptions +from zai.core._streaming import StreamResponse +from zai.types.audio import AudioSpeechChunk if TYPE_CHECKING: from zai._client import ZaiClient @@ -60,7 +61,7 @@ def speech( speed: float | None = 1.0, volume: float | None = 1.0, stream: bool | None = False - ) -> HttpxBinaryResponseContent: + ) -> HttpxBinaryResponseContent | StreamResponse[AudioSpeechChunk]: """ Generate speech audio from text input @@ -83,7 +84,6 @@ def speech( 'voice': voice, 'response_format': response_format, 'encode_format': encode_format, - 'sensitive_word_check': sensitive_word_check, 'request_id': request_id, 'user_id': user_id, 'speed': speed, @@ -96,6 +96,8 @@ def speech( body=maybe_transform(body, AudioSpeechParams), options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=HttpxBinaryResponseContent, + stream=stream or False, + stream_cls=StreamResponse[AudioSpeechChunk] ) def customization( diff --git a/src/zai/types/audio/__init__.py b/src/zai/types/audio/__init__.py index ad73c7b..f53dd5e 100644 --- a/src/zai/types/audio/__init__.py +++ b/src/zai/types/audio/__init__.py @@ -1,5 +1,6 @@ from .audio_customization_param import AudioCustomizationParam +from .audio_speech_chunk import AudioSpeechChunk from .audio_speech_params import AudioSpeechParams from .transcriptions_create_param import TranscriptionsParam -__all__ = ['AudioSpeechParams', 'AudioCustomizationParam', 'TranscriptionsParam'] +__all__ = ['AudioSpeechParams', 'AudioCustomizationParam', 'TranscriptionsParam', 'AudioSpeechChunk'] diff --git a/src/zai/types/audio/audio_speech_chunk.py b/src/zai/types/audio/audio_speech_chunk.py new file mode 100644 index 0000000..3d18c3b --- /dev/null +++ b/src/zai/types/audio/audio_speech_chunk.py @@ -0,0 +1,32 @@ +from typing import List, Optional, Dict, Any + +from ...core import BaseModel + +__all__ = [ + "AudioSpeechChunk", + "AudioError", + "AudioSpeechChoice", + "AudioSpeechDelta" +] + + +class AudioSpeechDelta(BaseModel): + content: Optional[str] = None + role: Optional[str] = None + + +class AudioSpeechChoice(BaseModel): + delta: AudioSpeechDelta + finish_reason: Optional[str] = None + index: int + +class AudioError(BaseModel): + code: Optional[str] = None + message: Optional[str] = None + + +class AudioSpeechChunk(BaseModel): + choices: List[AudioSpeechChoice] + request_id: Optional[str] = None + created: Optional[int] = None + error: Optional[AudioError] = None \ No newline at end of file diff --git a/src/zai/types/audio/audio_speech_params.py b/src/zai/types/audio/audio_speech_params.py index c8eb538..8843ab3 100644 --- a/src/zai/types/audio/audio_speech_params.py +++ b/src/zai/types/audio/audio_speech_params.py @@ -29,3 +29,7 @@ class AudioSpeechParams(TypedDict, total=False): sensitive_word_check: Optional[SensitiveWordCheckRequest] request_id: str user_id: str + encode_format: str + speed: float + volume: float + stream: bool diff --git a/tests/integration_tests/test_audio.py b/tests/integration_tests/test_audio.py index e9ef33c..3dbcf66 100644 --- a/tests/integration_tests/test_audio.py +++ b/tests/integration_tests/test_audio.py @@ -1,3 +1,4 @@ +import base64 import logging import logging.config from pathlib import Path @@ -5,7 +6,6 @@ import zai from zai import ZaiClient - def test_audio_speech(logging_conf): logging.config.dictConfig(logging_conf) # type: ignore client = ZaiClient() # Fill in your own API Key @@ -17,11 +17,19 @@ def test_audio_speech(logging_conf): voice='female', response_format='pcm', encode_format='base64', - stream=False, + stream=True, speed=1.0, volume=1.0, ) - response.stream_to_file(speech_file_path) + with open("output.pcm", "wb") as f: + for item in response: + choice = item.choices[0] + index = choice.index + finish_reason = choice.finish_reason + if choice.delta is None: + break + audio_delta = choice.delta.content + f.write(base64.b64decode(audio_delta)) except zai.core._errors.APIRequestFailedError as err: print(err)