Skip to content

Commit

Permalink
refactor: use http polling instead of websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethnym committed Aug 22, 2024
1 parent 13ae315 commit 58b4720
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 44 deletions.
50 changes: 31 additions & 19 deletions fal_app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import io
import datetime
from pathlib import Path
import threading
from audiocraft.data.audio import audio_write
import fal
from fastapi import WebSocket
from fastapi import Response, status
import torch

from prompts import PROMPTS

# Base directory for generated clips; per-clip paths are built from it both in
# the generation loop and in the /clips/{index} endpoint.
DATA_DIR = Path("/data/audio")

# Text prompts handed to the model in a single generate() call.
# NOTE(review): presumably one clip is produced per prompt (the polling server
# fetches clips 0-4, matching the five entries here) — confirm.
PROMPTS = [
"Create a futuristic lo-fi beat that blends modern electronic elements with synthwave influences. Incorporate smooth, atmospheric synths and gentle, relaxing rhythms to evoke a sense of a serene, neon-lit future. Ensure the track is continuous with no background noise or interruptions, maintaining a calm and tranquil atmosphere throughout while adding a touch of retro-futuristic vibes.",
"gentle lo-fi beat with a smooth, mellow piano melody in the background. Ensure there are no background noises or interruptions, maintaining a continuous and seamless flow throughout the track. The beat should be relaxing and tranquil, perfect for a calm and reflective atmosphere.",
"Create an earthy lo-fi beat that evokes a natural, grounded atmosphere. Incorporate organic sounds like soft percussion, rustling leaves, and gentle acoustic instruments. The track should have a warm, soothing rhythm with a continuous flow and no background noise or interruptions, maintaining a calm and reflective ambiance throughout.",
"Create a soothing lo-fi beat featuring gentle, melodic guitar riffs. The guitar should be the focal point, supported by subtle, ambient electronic elements and a smooth, relaxed rhythm. Ensure the track is continuous with no background noise or interruptions, maintaining a warm and mellow atmosphere throughout.",
"Create an ambient lo-fi beat with a tranquil and ethereal atmosphere. Use soft, atmospheric pads, gentle melodies, and minimalistic percussion to evoke a sense of calm and serenity. Ensure the track is continuous with no background noise or interruptions, maintaining a soothing and immersive ambiance throughout.",
]


class InfinifiFalApp(fal.App, keep_alive=300):
machine_type = "GPU-A6000"
Expand All @@ -17,8 +24,11 @@ class InfinifiFalApp(fal.App, keep_alive=300):
"audiocraft==1.3.0",
"torchaudio==2.1.0",
"websockets==11.0.3",
"numpy==1.26.4",
]

__is_generating = False

def setup(self):
import torchaudio
from audiocraft.models.musicgen import MusicGen
Expand All @@ -28,22 +38,26 @@ def setup(self):

# /generate endpoint: starts clip generation on a background thread so the
# caller can poll /clips/{index} for the result instead of holding the request
# open. Answers 409 Conflict while a generation is already flagged as running.
@fal.endpoint("/generate")
def run(self):
# NOTE(review): this synchronous generate() call looks like a line from the
# pre-polling version that this diff view shows without a removal marker;
# the worker thread below performs its own generate() — confirm it is gone.
wav = self.model.generate(PROMPTS)
if self.__is_generating:
return Response(status_code=status.HTTP_409_CONFLICT)
# NOTE(review): the flag is set inside __generate_audio, not here, so two
# requests arriving back-to-back could both pass the check above and start
# two workers — confirm whether that matters for this deployment.
threading.Thread(target=self.__generate_audio).start()

serialized = []
for one_wav in wav:
buf = io.BytesIO()
torch.save(one_wav.cpu(), buf)
serialized.append(buf.getvalue())
# /clips/{index} endpoint: returns the raw bytes of "<index>.mp3" under DATA_DIR.
# Answers 404 Not Found while generation is flagged as running; the polling
# server treats any non-200 reply as "still generating" and retries.
@fal.endpoint("/clips/{index}")
def get_clips(self, index):
if self.__is_generating:
return Response(status_code=status.HTTP_404_NOT_FOUND)

# NOTE(review): `serialized` is not defined in this method; this return looks
# like a line left over from the previous /generate body in the diff view —
# confirm it is not part of the committed file.
return serialized
# NOTE(review): `index` is interpolated into the path without validation, and
# open() will raise if the file does not exist yet (e.g. before the first
# generation has finished) — confirm how the caller is expected to handle it.
path = DATA_DIR.joinpath(f"{index}")
with open(path.with_suffix(".mp3"), "rb") as f:
data = f.read()
return Response(content=data)

@fal.endpoint("/ws")
async def run_ws(self, ws: WebSocket):
await ws.accept()
def __generate_audio(self):
self.__is_generating = True

wav = self.model.generate(PROMPTS)
print(f"[INFO] {datetime.datetime.now()}: generating audio...")

wav = self.model.generate(PROMPTS)
for i, one_wav in enumerate(wav):
path = DATA_DIR.joinpath(f"{i}")
audio_write(
Expand All @@ -53,9 +67,7 @@ async def run_ws(self, ws: WebSocket):
format="mp3",
strategy="loudness",
loudness_compressor=True,
make_parent_dir=True,
)
with open(path, "rb") as f:
data = f.read()
await ws.send_bytes(data)

await ws.close()
self.__is_generating = False
71 changes: 46 additions & 25 deletions server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import threading
import os
from time import sleep
import requests

import websocket
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from logger import log_info, log_warn
Expand All @@ -15,18 +17,18 @@
t = None
# websocket connection to the inference server
ws = None
ws_url = ""
inference_url = ""
ws_connection_manager = WebSocketConnectionManager()
active_listeners = set()


@asynccontextmanager
async def lifespan(app: FastAPI):
global ws, ws_url
global ws, inference_url

ws_url = os.environ.get("INFERENCE_SERVER_WS_URL")
if not ws_url:
ws_url = "ws://localhost:8001"
inference_url = os.environ.get("INFERENCE_SERVER_URL")
if not inference_url:
inference_url = "ws://localhost:8001"

advance()

Expand All @@ -39,7 +41,7 @@ async def lifespan(app: FastAPI):


def generate_new_audio():
if not ws_url:
if not inference_url:
return

global current_index
Expand All @@ -52,31 +54,50 @@ def generate_new_audio():
else:
return

log_info("generating new audio...")
log_info("requesting new audio...")

try:
ws = websocket.create_connection(ws_url)
print(f"{inference_url}/generate")
requests.post(f"{inference_url}/generate")
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return

ws.send("generate")
is_available = False
while not is_available:
try:
res = requests.post(f"{inference_url}/clips/0", stream=True)
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
return

wavs = []
for i in range(5):
raw = ws.recv()
if isinstance(raw, str):
continue
wavs.append(raw)
if res.status_code != status.HTTP_200_OK:
print("still generating...")
sleep(5)
continue

for i, wav in enumerate(wavs):
with open(f"{i + offset}.mp3", "wb") as f:
f.write(wav)
print("inference complete! downloading new clips")

log_info("audio generated.")
is_available = True
with open(f"{offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)

ws.close()
except:
log_warn(
"inference server potentially unreachable. recycling cached audio for now."
)
for i in range(4):
res = requests.post(f"{inference_url}/clips/{i + 1}", stream=True)

if res.status_code != status.HTTP_200_OK:
continue

with open(f"{i + 1 + offset}.mp3", "wb") as f:
for chunk in res.iter_content(chunk_size=128):
f.write(chunk)

log_info("audio generated.")


def advance():
Expand Down

0 comments on commit 58b4720

Please sign in to comment.