Skip to content

Commit

Permalink
feat: create inference ws server
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethnym committed Jul 22, 2024
1 parent e4e4fc5 commit e1866e0
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 3 deletions.
30 changes: 30 additions & 0 deletions inference_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio
from websockets.server import serve

# from generate import generate


async def handler(websocket):
async for message in websocket:
if message != "generate":
continue

print("generating new audio clips...")

# generate()

print("audio generated")

for i in range(5):
with open(f"{i + 5}.mp3", "rb") as f:
data = f.read()
await websocket.send(data)


async def main():
async with serve(handler, "", 8001):
await asyncio.Future()


if __name__ == "__main__":
asyncio.run(main())
3 changes: 3 additions & 0 deletions requirements-inference.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
audiocraft==1.3.0
torchaudio==2.1.0
websockets==11.0.3
2 changes: 2 additions & 0 deletions requirements-server.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
fastapi==0.111.1
websocket_client==1.8.0
35 changes: 32 additions & 3 deletions server.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import threading
import os

from generate import generate
import websocket
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from audiocraft.data.audio import audio_write

# the index of the current audio track from 0 to 9
current_index = -1
# the timer that periodically advances the current audio track
t = None
# websocket connection to the inference server
ws = None

prompts = [
"gentle, calming lo-fi beats that helps with studying and focusing",
Expand All @@ -23,13 +25,29 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
global ws

url = os.environ.get("INFERENCE_SERVER_WS_URL")
if not url:
url = "ws://localhost:8001"

ws = websocket.create_connection(url)
print(f"websocket connected to {url}")

advance()

yield

if ws:
ws.close()
if t:
t.cancel()


def generate_new_audio():
if not ws:
return

global current_index

offset = 0
Expand All @@ -42,7 +60,18 @@ def generate_new_audio():

print("generating new audio...")

generate(offset)
ws.send("generate")

wavs = []
for i in range(5):
raw = ws.recv()
if isinstance(raw, str):
continue
wavs.append(raw)

for i, wav in enumerate(wavs):
with open(f"{i + offset}.mp3", "wb") as f:
f.write(wav)

print("audio generated.")

Expand Down

0 comments on commit e1866e0

Please sign in to comment.