forked from runpod-workers/cog-worker
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
27 lines (25 loc) · 951 Bytes
/
predict.py
File metadata and controls
27 lines (25 loc) · 951 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import base64
import tempfile
from typing import Any

import requests
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    """Cog predictor that forwards text-to-speech requests to a local HTTP service.

    Synthesis is performed by a server assumed to be listening at
    ``http://localhost:8880`` with an OpenAI-style ``/v1/audio/speech``
    endpoint (presumably a Kokoro server, given the ``"kokoro"`` model name;
    confirm against the deployment). This class only builds the request and
    saves the returned MP3 bytes to a file.
    """

    # (connect, read) timeout in seconds for the TTS request. Without a
    # timeout, a stalled server would block the prediction indefinitely; the
    # read timeout is generous because long inputs can take a while to render.
    request_timeout = (10, 300)

    def setup(self) -> None:
        """Record the URL of the speech-synthesis endpoint used by ``predict``."""
        self.tts_url = "http://localhost:8880/v1/audio/speech"

    def predict(
        self,
        text: str = Input(description="Text for TTS"),
        voice: str = Input(default="af_bella", description="Voice ID"),
        speed: float = Input(default=1.0, description="Speed"),
    ) -> Path:
        """Synthesize *text* with *voice* at *speed* and return the MP3 file path.

        Raises:
            requests.HTTPError: if the TTS server responds with a 4xx/5xx status.
            requests.RequestException: on connection failure or timeout.
        """
        payload = {
            "model": "kokoro",
            "input": text,
            "voice": voice,
            "speed": speed,
            "response_format": "mp3",
        }
        response = requests.post(
            self.tts_url, json=payload, timeout=self.request_timeout
        )
        response.raise_for_status()

        # Previously this wrote to ``self.output_path``, which is never
        # assigned, so every call raised AttributeError. Use a fresh
        # temporary directory per call so concurrent or repeated predictions
        # never overwrite each other's output file.
        output_path = Path(tempfile.mkdtemp()) / "audio.mp3"
        output_path.write_bytes(response.content)
        return output_path