
Commit b2027a8

Merge pull request #253 from runpod/local-streaming
Local streaming
2 parents 1407756 + 0b572c8 commit b2027a8

File tree

8 files changed: +410 −76 lines changed


CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,14 @@
 # Change Log
 
+## Release 1.4.1 (12/13/23)
+
+### Added
+
+- Local test API server includes simulated endpoints that mimic the behavior of `run`, `runsync`, `stream`, and `status`.
+- Internal job tracker can be used to track job inputs.
+
+---
+
 ## Release 1.4.0 (12/4/23)
 
 ### Changed
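Review note: a minimal sketch of exercising the new simulated endpoints once a handler is served locally with `--rp_serve_api` (see the handler file below). It assumes the test server's defaults of `localhost:8000` and that the `requests` package is installed; neither is part of this diff.

import requests

# POST /runsync returns the handler output directly (simulated runsync behavior).
# The body shape {"input": {...}} matches the DefaultInput model added in rp_fastapi.py below.
resp = requests.post(
    "http://localhost:8000/runsync",
    json={"input": {"name": "World"}},
    timeout=10,
)
print(resp.json())  # e.g. {'id': 'test-<uuid>', 'status': 'COMPLETED', 'output': 'Hello World!'}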
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+""" Simple Handler
+
+To setup a local API server, run the following command:
+    python simple_handler.py --rp_serve_api
+"""
+
+import runpod
+
+
+def handler(job):
+    """ Simple handler """
+    job_input = job["input"]
+
+    return f"Hello {job_input['name']}!"
+
+
+runpod.serverless.start({"handler": handler})
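Review note: the simulated `/stream` endpoint added in `rp_fastapi.py` below only aggregates output when the handler is a generator. A hypothetical streaming variant of the handler above (a sketch, not part of this commit; the filename is made up):

""" Hypothetical streaming handler (a sketch, not part of this commit).

Start a local test server the same way:
    python streaming_handler.py --rp_serve_api
"""

import runpod


def handler(job):
    """ Yields partial outputs so the simulated /stream endpoint can collect them. """
    for word in ("Hello", job["input"]["name"], "!"):
        yield word


runpod.serverless.start({"handler": handler})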

runpod/serverless/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -126,8 +126,7 @@ def start(config: Dict[str, Any]):
 
     if config["rp_args"]["rp_serve_api"]:
         print("Starting API server.")
-        api_server = rp_fastapi.WorkerAPI()
-        api_server.config = config
+        api_server = rp_fastapi.WorkerAPI(config)
 
         api_server.start_uvicorn(
             api_host=config['rp_args']['rp_api_host'],
@@ -137,8 +136,7 @@ def start(config: Dict[str, Any]):
 
     elif realtime_port:
         print("Starting API server for realtime.")
-        api_server = rp_fastapi.WorkerAPI()
-        api_server.config = config
+        api_server = rp_fastapi.WorkerAPI(config)
 
         api_server.start_uvicorn(
             api_host='0.0.0.0',
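Review note: the constructor now receives the full config instead of having `.config` assigned after construction. A sketch of wiring `WorkerAPI` up by hand for a local test, assuming only the `handler` key is required and reusing the `start_uvicorn` defaults shown in the `rp_fastapi.py` hunk below; these assumptions are inferred from this diff, not documented API.

from runpod.serverless.modules import rp_fastapi


def handler(job):
    """ Placeholder handler for this sketch. """
    return f"Hello {job['input']['name']}!"


# The simulation endpoints read config["handler"]; normally start() builds this dict.
api_server = rp_fastapi.WorkerAPI({"handler": handler})

# Defaults mirror the start_uvicorn signature shown in the hunk header below.
api_server.start_uvicorn(api_host="localhost", api_port=8000, api_concurrency=1)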

runpod/serverless/modules/rp_fastapi.py

Lines changed: 148 additions & 30 deletions
@@ -2,11 +2,13 @@
 # pylint: disable=too-few-public-methods
 
 import os
-from typing import Union
+import uuid
+from typing import Union, Optional, Dict, Any
 
 import uvicorn
 from fastapi import FastAPI, APIRouter
 from fastapi.encoders import jsonable_encoder
+from fastapi.responses import RedirectResponse
 from pydantic import BaseModel
 
 from .rp_handler import is_generator
@@ -47,14 +49,39 @@ class TestJob(BaseModel):
     ''' Represents a test job.
         input can be any type of data.
     '''
-    id: str = "test_job"
-    input: Union[dict, list, str, int, float, bool]
+    id: Optional[str]
+    input: Optional[Union[dict, list, str, int, float, bool]]
+
+
+class DefaultInput(BaseModel):
+    """ Represents a test input. """
+    input: Dict[str, Any]
+
+
+# ------------------------------ Output Objects ------------------------------ #
+class JobOutput(BaseModel):
+    ''' Represents the output of a job. '''
+    id: str
+    status: str
+    output: Optional[Union[dict, list, str, int, float, bool]]
+    error: Optional[str]
+
+
+class StreamOutput(BaseModel):
+    """ Stream representation of a job. """
+    id: str
+    status: str = "IN_PROGRESS"
+    stream: Optional[Union[dict, list, str, int, float, bool]]
+    error: Optional[str]
 
 
+# ---------------------------------------------------------------------------- #
+#                                  API Worker                                   #
+# ---------------------------------------------------------------------------- #
 class WorkerAPI:
     ''' Used to launch the FastAPI web server when the worker is running in API mode. '''
 
-    def __init__(self, handler=None):
+    def __init__(self, config: Dict[str, Any]):
         '''
         Initializes the WorkerAPI class.
         1. Starts the heartbeat thread.
@@ -64,23 +91,50 @@ def __init__(self, handler=None):
         # Start the heartbeat thread.
         heartbeat.start_ping()
 
-        # Set the handler for processing jobs.
-        self.config = {"handler": handler}
+        self.config = config
 
         # Initialize the FastAPI web server.
         self.rp_app = FastAPI(
             title="RunPod | Test Worker | API",
             description=DESCRIPTION,
             version=runpod_version,
+            docs_url="/"
         )
 
         # Create an APIRouter and add the route for processing jobs.
         api_router = APIRouter()
 
-        if RUNPOD_ENDPOINT_ID:
-            api_router.add_api_route(f"/{RUNPOD_ENDPOINT_ID}/realtime", self._run, methods=["POST"])
+        # Docs Redirect /docs -> /
+        api_router.add_api_route(
+            "/docs", lambda: RedirectResponse(url="/"),
+            include_in_schema=False
+        )
 
-        api_router.add_api_route("/runsync", self._debug_run, methods=["POST"])
+        if RUNPOD_ENDPOINT_ID:
+            api_router.add_api_route(f"/{RUNPOD_ENDPOINT_ID}/realtime",
+                                     self._realtime, methods=["POST"])
+
+        # Simulation endpoints.
+        api_router.add_api_route(
+            "/run", self._sim_run, methods=["POST"], response_model_exclude_none=True,
+            summary="Simulate run behavior.",
+            description="Returns job ID to be used with `/stream` and `/status` endpoints."
+        )
+        api_router.add_api_route(
+            "/runsync", self._sim_runsync, methods=["POST"], response_model_exclude_none=True,
+            summary="Simulate runsync behavior.",
+            description="Returns job output directly when called."
+        )
+        api_router.add_api_route(
+            "/stream/{job_id}", self._sim_stream, methods=["POST"],
+            response_model_exclude_none=True, summary="Simulate stream behavior.",
+            description="Aggregates the output of the job and returns it when the job is complete."
+        )
+        api_router.add_api_route(
+            "/status/{job_id}", self._sim_status, methods=["POST"],
+            response_model_exclude_none=True, summary="Simulate status behavior.",
+            description="Returns the output of the job when the job is complete."
+        )
 
         # Include the APIRouter in the FastAPI application.
         self.rp_app.include_router(api_router)
@@ -96,47 +150,111 @@ def start_uvicorn(self, api_host='localhost', api_port=8000, api_concurrency=1):
             access_log=False
         )
 
-    async def _run(self, job: Job):
+    # ----------------------------- Realtime Endpoint ---------------------------- #
+    async def _realtime(self, job: Job):
         '''
         Performs model inference on the input data using the provided handler.
         If handler is not provided, returns an error message.
         '''
-        if self.config["handler"] is None:
-            return {"error": "Handler not provided"}
-
-        # Set the current job ID.
         job_list.add_job(job.id)
 
-        # Process the job using the provided handler.
+        # Process the job using the provided handler, passing in the job input.
        job_results = await run_job(self.config["handler"], job.__dict__)
 
-        # Reset the job ID.
         job_list.remove_job(job.id)
 
         # Return the results of the job processing.
         return jsonable_encoder(job_results)
 
-    async def _debug_run(self, job: TestJob):
-        '''
-        Performs model inference on the input data using the provided handler.
-        '''
-        if self.config["handler"] is None:
-            return {"error": "Handler not provided"}
+    # ---------------------------------------------------------------------------- #
+    #                              Simulation Endpoints                             #
+    # ---------------------------------------------------------------------------- #
 
-        # Set the current job ID.
-        job_list.add_job(job.id)
+    # ------------------------------------ run ----------------------------------- #
+    async def _sim_run(self, job_input: DefaultInput) -> JobOutput:
+        """ Development endpoint to simulate run behavior. """
+        assigned_job_id = f"test-{uuid.uuid4()}"
+        job_list.add_job(assigned_job_id, job_input.input)
+        return jsonable_encoder({"id": assigned_job_id, "status": "IN_PROGRESS"})
+
+    # ---------------------------------- runsync --------------------------------- #
+    async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput:
+        """ Development endpoint to simulate runsync behavior. """
+        assigned_job_id = f"test-{uuid.uuid4()}"
+        job = TestJob(id=assigned_job_id, input=job_input.input)
 
         if is_generator(self.config["handler"]):
             generator_output = run_job_generator(self.config["handler"], job.__dict__)
-            job_results = {"output": []}
+            job_output = {"output": []}
             async for stream_output in generator_output:
-                job_results["output"].append(stream_output["output"])
+                job_output['output'].append(stream_output["output"])
         else:
-            job_results = await run_job(self.config["handler"], job.__dict__)
+            job_output = await run_job(self.config["handler"], job.__dict__)
+
+        return jsonable_encoder({
+            "id": job.id,
+            "status": "COMPLETED",
+            "output": job_output['output']
+        })
+
+    # ---------------------------------- stream ---------------------------------- #
+    async def _sim_stream(self, job_id: str) -> StreamOutput:
+        """ Development endpoint to simulate stream behavior. """
+        job_input = job_list.get_job_input(job_id)
+        if job_input is None:
+            return jsonable_encoder({
+                "id": job_id,
+                "status": "FAILED",
+                "error": "Job ID not found"
+            })
+
+        job = TestJob(id=job_id, input=job_input)
 
-        job_results["id"] = job.id
+        if is_generator(self.config["handler"]):
+            generator_output = run_job_generator(self.config["handler"], job.__dict__)
+            stream_accumulator = []
+            async for stream_output in generator_output:
+                stream_accumulator.append({"output": stream_output["output"]})
+        else:
+            return jsonable_encoder({
+                "id": job_id,
+                "status": "FAILED",
+                "error": "Stream not supported, handler must be a generator."
+            })
 
-        # Reset the job ID.
         job_list.remove_job(job.id)
 
-        return jsonable_encoder(job_results)
+        return jsonable_encoder({
+            "id": job_id,
+            "status": "COMPLETED",
+            "stream": stream_accumulator
+        })
+
+    # ---------------------------------- status ---------------------------------- #
+    async def _sim_status(self, job_id: str) -> JobOutput:
+        """ Development endpoint to simulate status behavior. """
+        job_input = job_list.get_job_input(job_id)
+        if job_input is None:
+            return jsonable_encoder({
+                "id": job_id,
+                "status": "FAILED",
+                "error": "Job ID not found"
+            })
+
+        job = TestJob(id=job_id, input=job_input)
+
+        if is_generator(self.config["handler"]):
+            generator_output = run_job_generator(self.config["handler"], job.__dict__)
+            job_output = {"output": []}
+            async for stream_output in generator_output:
+                job_output['output'].append(stream_output["output"])
+        else:
+            job_output = await run_job(self.config["handler"], job.__dict__)
+
+        job_list.remove_job(job.id)
+
+        return jsonable_encoder({
+            "id": job_id,
+            "status": "COMPLETED",
+            "output": job_output['output']
+        })
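Review note: taken together, these simulation endpoints reproduce the hosted job lifecycle locally. A sketch of the `/run` → `/stream` flow, assuming the test server on `localhost:8000`, the `requests` package, and a generator handler (the non-generator branch above returns the FAILED stream error); none of this is part of the diff itself.

import requests

BASE = "http://localhost:8000"

# 1. /run stores the input in the in-memory job tracker and returns a test job ID.
job = requests.post(f"{BASE}/run", json={"input": {"name": "World"}}, timeout=10).json()
print(job)  # e.g. {'id': 'test-<uuid>', 'status': 'IN_PROGRESS'}

# 2. /stream/{job_id} runs the generator handler and aggregates the yielded partials.
stream = requests.post(f"{BASE}/stream/{job['id']}", timeout=10).json()
print(stream)  # e.g. {'id': '...', 'status': 'COMPLETED', 'stream': [{'output': ...}, ...]}

# 3. /status/{job_id} returns the collected output instead. Note that both /stream and
#    /status remove the job from the tracker when they finish, so a second call with
#    the same ID comes back as the "Job ID not found" failure.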

runpod/serverless/modules/rp_job.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 # pylint: disable=too-many-branches
 
 import inspect
-from typing import Any, Callable, Dict, Generator, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union, AsyncGenerator
 
 import os
 import json
@@ -179,9 +179,9 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
 
 async def run_job_generator(
         handler: Callable,
-        job: Dict[str, Any]) -> Generator[Dict[str, Union[str, Any]], None, None]:
+        job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Union[str, Any]], None]:
     '''
-    Run generator job.
+    Run generator job used to stream output.
     Yields output partials from the generator.
     '''
     try:
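Review note: the `Generator` → `AsyncGenerator` change matches what `run_job_generator` actually is, an `async def` that yields. `AsyncGenerator[YieldType, SendType]` has no third (return-type) slot, unlike `Generator`, which describes synchronous generators. A standalone typing sketch, not repo code:

import asyncio
from typing import Any, AsyncGenerator, Dict


async def partials() -> AsyncGenerator[Dict[str, Any], None]:
    """ An async def that yields is an async generator function. """
    for chunk in ("Hel", "lo"):
        yield {"output": chunk}


async def main():
    # Consumed with `async for`, the same way rp_fastapi.py drains run_job_generator.
    async for part in partials():
        print(part)


asyncio.run(main())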

runpod/serverless/modules/worker_state.py

Lines changed: 35 additions & 7 deletions
@@ -5,6 +5,7 @@
 import os
 import uuid
 import time
+from typing import Optional, Dict, Any, Union
 
 REF_COUNT_ZERO = time.perf_counter()  # Used for benchmarking with the debugger.
 
@@ -22,6 +23,25 @@ def get_auth_header():
     return {"Authorization": f"{os.environ.get('RUNPOD_AI_API_KEY')}"}
 
 
+# ------------------------------- Job Tracking ------------------------------- #
+class Job:
+    """ Represents a job. """
+
+    def __init__(self, job_id: str, job_input: Optional[Dict[str, Any]] = None) -> None:
+        self.job_id = job_id
+        self.job_input = job_input
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, Job):
+            return self.job_id == other.job_id
+        return False
+
+    def __hash__(self) -> int:
+        return hash(self.job_id)
+
+    def __str__(self) -> str:
+        return self.job_id
+
 
 class Jobs:
     ''' Track the state of current jobs.'''
@@ -35,23 +55,31 @@ def __new__(cls):
             Jobs._instance.jobs = set()
         return Jobs._instance
 
-    def add_job(self, job_id):
+    def add_job(self, job_id, job_input=None):
         '''
         Adds a job to the list of jobs.
         '''
-        self.jobs.add(job_id)
+        self.jobs.add(Job(job_id, job_input))
 
     def remove_job(self, job_id):
         '''
         Removes a job from the list of jobs.
         '''
-        self.jobs.remove(job_id)
+        self.jobs.remove(Job(job_id))
+
+    def get_job_input(self, job_id) -> Optional[Union[dict, list, str, int, float, bool]]:
+        '''
+        Returns the job with the given id.
+        Used within rp_fastapi.py for local testing.
+        '''
+        for job in self.jobs:
+            if job.job_id == job_id:
+                return job.job_input
+
+        return None
 
     def get_job_list(self):
         '''
         Returns the list of jobs as a string.
         '''
-        if len(self.jobs) == 0:
-            return None
-
-        return ','.join(list(self.jobs))
+        return ','.join(str(job) for job in self.jobs) if self.jobs else None
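Review note: because `Job.__eq__` and `__hash__` key on `job_id` alone, the `Jobs` set can look up and remove entries by ID regardless of the stored input. A small usage sketch of the tracker (the singleton import path is the file shown above):

from runpod.serverless.modules.worker_state import Jobs

jobs = Jobs()  # singleton: every caller shares the same job set

jobs.add_job("test-123", {"name": "World"})   # track the job alongside its input
print(jobs.get_job_input("test-123"))         # -> {'name': 'World'}
print(jobs.get_job_list())                    # -> "test-123"

jobs.remove_job("test-123")                   # Job("test-123") matches by job_id only
print(jobs.get_job_input("test-123"))         # -> None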
