From e1866e08f5d485ecc933f8f60133f761b479c335 Mon Sep 17 00:00:00 2001 From: Kenneth Date: Mon, 22 Jul 2024 22:24:39 +0100 Subject: [PATCH] feat: create inference ws server --- inference_server.py | 30 ++++++++++++++++++++++++++++++ requirements-inference.txt | 3 +++ requirements-server.txt | 2 ++ server.py | 35 ++++++++++++++++++++++++++++++++--- 4 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 inference_server.py create mode 100644 requirements-inference.txt create mode 100644 requirements-server.txt diff --git a/inference_server.py b/inference_server.py new file mode 100644 index 0000000..924e0dc --- /dev/null +++ b/inference_server.py @@ -0,0 +1,30 @@ +import asyncio +from websockets.server import serve + +# from generate import generate + + +async def handler(websocket): + async for message in websocket: + if message != "generate": + continue + + print("generating new audio clips...") + + # generate() + + print("audio generated") + + for i in range(5): + with open(f"{i + 5}.mp3", "rb") as f: + data = f.read() + await websocket.send(data) + + +async def main(): + async with serve(handler, "", 8001): + await asyncio.Future() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/requirements-inference.txt b/requirements-inference.txt new file mode 100644 index 0000000..6915edc --- /dev/null +++ b/requirements-inference.txt @@ -0,0 +1,3 @@ +audiocraft==1.3.0 +torchaudio==2.1.0 +websockets==11.0.3 diff --git a/requirements-server.txt b/requirements-server.txt new file mode 100644 index 0000000..260e9f6 --- /dev/null +++ b/requirements-server.txt @@ -0,0 +1,2 @@ +fastapi==0.111.1 +websocket_client==1.8.0 diff --git a/server.py b/server.py index bb9d16d..12222ce 100644 --- a/server.py +++ b/server.py @@ -1,16 +1,18 @@ import threading +import os -from generate import generate +import websocket from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles -from audiocraft.data.audio import audio_write # the index of the current audio track from 0 to 9 current_index = -1 # the timer that periodically advances the current audio track t = None +# websocket connection to the inference server +ws = None prompts = [ "gentle, calming lo-fi beats that helps with studying and focusing", @@ -23,13 +25,29 @@ @asynccontextmanager async def lifespan(app: FastAPI): + global ws + + url = os.environ.get("INFERENCE_SERVER_WS_URL") + if not url: + url = "ws://localhost:8001" + + ws = websocket.create_connection(url) + print(f"websocket connected to {url}") + advance() + yield + + if ws: + ws.close() if t: t.cancel() def generate_new_audio(): + if not ws: + return + global current_index offset = 0 @@ -42,7 +60,18 @@ def generate_new_audio(): print("generating new audio...") - generate(offset) + ws.send("generate") + + wavs = [] + for i in range(5): + raw = ws.recv() + if isinstance(raw, str): + continue + wavs.append(raw) + + for i, wav in enumerate(wavs): + with open(f"{i + offset}.mp3", "wb") as f: + f.write(wav) print("audio generated.")