diff --git a/inference_server.py b/inference_server.py index 04bbdb1..688bd7d 100644 --- a/inference_server.py +++ b/inference_server.py @@ -2,6 +2,7 @@ from websockets.server import serve from generate import generate +from logger import log_info async def handler(websocket): @@ -9,11 +10,11 @@ async def handler(websocket): if message != "generate": continue - print("generating new audio clips...") + log_info("generating new audio clips...") generate() - print("audio generated") + log_info("audio generated") for i in range(5): with open(f"{i}.mp3", "rb") as f: diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..a125c68 --- /dev/null +++ b/logger.py @@ -0,0 +1,9 @@ +import datetime + + +def log_info(message: str): + print(f"[INFO] {datetime.datetime.now()}: {message}") + + +def log_warn(message: str): + print(f"[WARN] {datetime.datetime.now()}: {message}") diff --git a/server.py b/server.py index 96a75fc..d6d7f70 100644 --- a/server.py +++ b/server.py @@ -6,6 +6,7 @@ from fastapi import FastAPI from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles +from logger import log_info, log_warn # the index of the current audio track from 0 to 9 current_index = -1 @@ -48,27 +49,31 @@ def generate_new_audio(): else: return - print("generating new audio...") + log_info("generating new audio...") - ws = websocket.create_connection(ws_url) - print(f"websocket connected to {ws_url}") + try: + ws = websocket.create_connection(ws_url) - ws.send("generate") + ws.send("generate") - wavs = [] - for i in range(5): - raw = ws.recv() - if isinstance(raw, str): - continue - wavs.append(raw) + wavs = [] + for i in range(5): + raw = ws.recv() + if isinstance(raw, str): + continue + wavs.append(raw) - for i, wav in enumerate(wavs): - with open(f"{i + offset}.mp3", "wb") as f: - f.write(wav) + for i, wav in enumerate(wavs): + with open(f"{i + offset}.mp3", "wb") as f: + f.write(wav) - print("audio generated.") + log_info("audio generated.") - ws.close() + ws.close() + except: + log_warn( + "inference server potentially unreachable. recycling cached audio for now." + ) def advance(): @@ -79,7 +84,7 @@ def advance(): else: current_index = current_index + 1 - # threading.Thread(target=generate_new_audio).start() + threading.Thread(target=generate_new_audio).start() t = threading.Timer(60, advance) t.start()