Skip to content

Commit ef98b95

Browse files
Merge pull request #6 from runpod/FastAPI
Fast api
2 parents ec190ca + 7445c24 commit ef98b95

File tree

9 files changed

+85
-3
lines changed

9 files changed

+85
-3
lines changed

.github/workflows/ci_pylint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ jobs:
3434
3535
- name: Pylint Source
3636
run: |
37-
find . -type f -name '*.py' | xargs pylint
37+
find . -type f -name '*.py' | xargs pylint --extension-pkg-whitelist='pydantic'

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
env
2+
.env
3+
14
# Byte-compiled / optimized / DLL files
25
__pycache__/
36
*.py[cod]

docs/serverless/worker.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ RUNPOD_WEBHOOK_GET_JOB= # URL to get job work from
1515
RUNPOD_WEBHOOK_POST_OUTPUT= # URL to post output to
1616
RUNPOD_WEBHOOK_PING= # URL to ping
1717
RUNPOD_PING_INTERVAL= # Interval in milliseconds to ping the API (Default: 10000)
18+
19+
# Realtime
20+
RUNPOD_REALTIME_PORT= # Port to listen on for realtime connections (Default: None)
21+
RUNPOD_REALTIME_CONCURRENCY= # Number of workers to spawn (Default: 1)
1822
```
1923

2024
### Additional Variables

infer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
'''
66
# pylint: disable=unused-argument,too-few-public-methods
77

8+
import runpod
9+
810

911
def validator():
1012
'''
@@ -38,3 +40,6 @@ def run(model_inputs):
3840
"seed": "1234"
3941
}
4042
]
43+
44+
45+
runpod.serverless.start({"handler": run})

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ python-dotenv >= 0.21.0
33
requests >= 2.28.1
44
boto3 >= 1.26.15
55
aiohttp >= 3.8.3
6+
fastapi[all] >= 0.89.0

runpod/serverless/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
''' Allows serverless to recognized as a package.'''
22

3+
import os
34
import asyncio
45

56
from . import work_loop
7+
from .modules import rp_fastapi
68

79

810
def start(config):
911
'''
1012
Starts the serverless worker.
1113
'''
12-
asyncio.run(work_loop.start_worker(config))
14+
api_port = os.environ.get('RUNPOD_API_PORT', None)
15+
16+
if api_port:
17+
api_server = rp_fastapi.WorkerAPI()
18+
api_server.config = config
19+
20+
api_server.start_uvicorn(api_port)
21+
else:
22+
asyncio.run(work_loop.start_worker(config))

runpod/serverless/modules/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def log(message, level='INFO'):
1717
Log message to stdout if RUNPOD_DEBUG is true.
1818
'''
1919
if os.environ.get('RUNPOD_DEBUG', 'true') == 'true':
20-
print(f'{level} | {message}')
20+
print(f'{level} | {message}', flush=True)
2121

2222

2323
def log_secret(secret_name, secret, level='INFO'):
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
''' Used to launch the FastAPI web server when worker is running in API mode. '''
2+
3+
import os
4+
import threading
5+
6+
import uvicorn
7+
from fastapi import FastAPI
8+
from fastapi.encoders import jsonable_encoder
9+
from pydantic import BaseModel
10+
11+
from .job import run_job
12+
from .worker_state import set_job_id
13+
from .heartbeat import start_heartbeat
14+
15+
16+
class Job(BaseModel):
17+
''' Represents a job. '''
18+
id: str
19+
input: dict
20+
21+
22+
class WorkerAPI:
23+
''' Used to launch the FastAPI web server when worker is running in API mode. '''
24+
25+
def __init__(self):
26+
'''
27+
Initializes the WorkerAPI class.
28+
1. Starts the heartbeat thread.
29+
2. Initializes the FastAPI web server.
30+
'''
31+
heartbeat_thread = threading.Thread(target=start_heartbeat)
32+
heartbeat_thread.daemon = True
33+
heartbeat_thread.start()
34+
35+
self.config = {"handler": None}
36+
self.rp_app = FastAPI()
37+
self.rp_app.add_api_route("/run", self.run, methods=["POST"])
38+
39+
def start_uvicorn(self, api_port):
40+
'''
41+
Starts the Uvicorn server.
42+
'''
43+
uvicorn.run(
44+
self.rp_app, host='0.0.0.0', port=int(api_port),
45+
workers=os.environ.get('RUNPOD_REALTIME_CONCURRENCY', 1)
46+
)
47+
48+
async def run(self, job: Job):
49+
'''
50+
Performs model inference on the input data.
51+
'''
52+
set_job_id(job.id)
53+
54+
job_results = run_job(self.config["handler"], job.__dict__)
55+
56+
set_job_id(None)
57+
58+
return jsonable_encoder(job_results)

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ install_requires =
3131
requests >= 2.28.1
3232
boto3 >= 1.26.15
3333
aiohttp >= 3.8.3
34+
fastapi[all] >= 0.89.0

0 commit comments

Comments
 (0)