2
2
# pylint: disable=too-few-public-methods
3
3
4
4
import os
5
- from typing import Union
5
+ import uuid
6
+ from typing import Union , Optional , Dict , Any
6
7
7
8
import uvicorn
8
9
from fastapi import FastAPI , APIRouter
9
10
from fastapi .encoders import jsonable_encoder
11
+ from fastapi .responses import RedirectResponse
10
12
from pydantic import BaseModel
11
13
12
14
from .rp_handler import is_generator
@@ -47,14 +49,39 @@ class TestJob(BaseModel):
47
49
''' Represents a test job.
48
50
input can be any type of data.
49
51
'''
50
- id : str = "test_job"
51
- input : Union [dict , list , str , int , float , bool ]
52
+ id : Optional [str ]
53
+ input : Optional [Union [dict , list , str , int , float , bool ]]
54
+
55
+
56
class DefaultInput(BaseModel):
    """Request body model wrapping the input payload of a test job."""
    # String-keyed mapping holding the job's input payload; values are unconstrained.
    input: Dict[str, Any]
59
+
60
+
61
+ # ------------------------------ Output Objects ------------------------------ #
62
class JobOutput(BaseModel):
    ''' Represents the output of a job. '''
    # Identifier of the job this result belongs to.
    id: str
    # Job state string, e.g. "IN_PROGRESS", "COMPLETED" or "FAILED".
    status: str
    # `output` and `error` are each present only on some responses, so they
    # default to None explicitly. Without a default, pydantic v2 treats an
    # Optional[...] field as required and rejects responses that omit it.
    output: Optional[Union[dict, list, str, int, float, bool]] = None
    error: Optional[str] = None
68
+
69
+
70
class StreamOutput(BaseModel):
    """ Stream representation of a job. """
    # Identifier of the job this stream belongs to.
    id: str
    # Job state string; a stream is considered in progress unless stated otherwise.
    status: str = "IN_PROGRESS"
    # `stream` and `error` are each present only on some responses, so they
    # default to None explicitly. Without a default, pydantic v2 treats an
    # Optional[...] field as required and rejects responses that omit it.
    stream: Optional[Union[dict, list, str, int, float, bool]] = None
    error: Optional[str] = None
52
76
53
77
78
+ # ---------------------------------------------------------------------------- #
79
+ # API Worker #
80
+ # ---------------------------------------------------------------------------- #
54
81
class WorkerAPI :
55
82
''' Used to launch the FastAPI web server when the worker is running in API mode. '''
56
83
57
- def __init__ (self , handler = None ):
84
+ def __init__ (self , config : Dict [ str , Any ] ):
58
85
'''
59
86
Initializes the WorkerAPI class.
60
87
1. Starts the heartbeat thread.
@@ -64,23 +91,50 @@ def __init__(self, handler=None):
64
91
# Start the heartbeat thread.
65
92
heartbeat .start_ping ()
66
93
67
- # Set the handler for processing jobs.
68
- self .config = {"handler" : handler }
94
+ self .config = config
69
95
70
96
# Initialize the FastAPI web server.
71
97
self .rp_app = FastAPI (
72
98
title = "RunPod | Test Worker | API" ,
73
99
description = DESCRIPTION ,
74
100
version = runpod_version ,
101
+ docs_url = "/"
75
102
)
76
103
77
104
# Create an APIRouter and add the route for processing jobs.
78
105
api_router = APIRouter ()
79
106
80
- if RUNPOD_ENDPOINT_ID :
81
- api_router .add_api_route (f"/{ RUNPOD_ENDPOINT_ID } /realtime" , self ._run , methods = ["POST" ])
107
+ # Docs Redirect /docs -> /
108
+ api_router .add_api_route (
109
+ "/docs" , lambda : RedirectResponse (url = "/" ),
110
+ include_in_schema = False
111
+ )
82
112
83
- api_router .add_api_route ("/runsync" , self ._debug_run , methods = ["POST" ])
113
+ if RUNPOD_ENDPOINT_ID :
114
+ api_router .add_api_route (f"/{ RUNPOD_ENDPOINT_ID } /realtime" ,
115
+ self ._realtime , methods = ["POST" ])
116
+
117
+ # Simulation endpoints.
118
+ api_router .add_api_route (
119
+ "/run" , self ._sim_run , methods = ["POST" ], response_model_exclude_none = True ,
120
+ summary = "Simulate run behavior." ,
121
+ description = "Returns job ID to be used with `/stream` and `/status` endpoints."
122
+ )
123
+ api_router .add_api_route (
124
+ "/runsync" , self ._sim_runsync , methods = ["POST" ], response_model_exclude_none = True ,
125
+ summary = "Simulate runsync behavior." ,
126
+ description = "Returns job output directly when called."
127
+ )
128
+ api_router .add_api_route (
129
+ "/stream/{job_id}" , self ._sim_stream , methods = ["POST" ],
130
+ response_model_exclude_none = True , summary = "Simulate stream behavior." ,
131
+ description = "Aggregates the output of the job and returns it when the job is complete."
132
+ )
133
+ api_router .add_api_route (
134
+ "/status/{job_id}" , self ._sim_status , methods = ["POST" ],
135
+ response_model_exclude_none = True , summary = "Simulate status behavior." ,
136
+ description = "Returns the output of the job when the job is complete."
137
+ )
84
138
85
139
# Include the APIRouter in the FastAPI application.
86
140
self .rp_app .include_router (api_router )
@@ -96,47 +150,111 @@ def start_uvicorn(self, api_host='localhost', api_port=8000, api_concurrency=1):
96
150
access_log = False
97
151
)
98
152
99
- async def _run (self , job : Job ):
153
+ # ----------------------------- Realtime Endpoint ---------------------------- #
154
async def _realtime(self, job: Job):
    '''
    Runs the configured handler on a realtime job and returns its result.

    The job ID is registered in `job_list` while the handler runs and is
    removed again afterwards, even when the handler call raises.

    Args:
        job: The incoming job; its attribute dict is passed to `run_job`.

    Returns:
        The JSON-encodable form of whatever `run_job` returned.
    '''
    job_list.add_job(job.id)

    try:
        # Process the job using the provided handler, passing in the job input.
        job_results = await run_job(self.config["handler"], job.__dict__)
    finally:
        # Always stop tracking the job so a failed run does not leave a stale ID.
        job_list.remove_job(job.id)

    # Return the results of the job processing.
    return jsonable_encoder(job_results)
118
168
119
- async def _debug_run (self , job : TestJob ):
120
- '''
121
- Performs model inference on the input data using the provided handler.
122
- '''
123
- if self .config ["handler" ] is None :
124
- return {"error" : "Handler not provided" }
169
+ # ---------------------------------------------------------------------------- #
170
+ # Simulation Endpoints #
171
+ # ---------------------------------------------------------------------------- #
125
172
126
- # Set the current job ID.
127
- job_list .add_job (job .id )
173
+ # ------------------------------------ run ----------------------------------- #
174
async def _sim_run(self, job_input: DefaultInput) -> JobOutput:
    """
    Development endpoint to simulate run behavior.

    Generates a fresh "test-" prefixed job ID, stores the submitted input
    under it in `job_list`, and reports the job as in progress without
    invoking the handler.
    """
    new_id = f"test-{uuid.uuid4()}"
    job_list.add_job(new_id, job_input.input)

    response = {"id": new_id, "status": "IN_PROGRESS"}
    return jsonable_encoder(response)
179
+
180
+ # ---------------------------------- runsync --------------------------------- #
181
async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput:
    """
    Development endpoint to simulate runsync behavior.

    Builds a test job with a fresh "test-" prefixed ID, runs the configured
    handler on it immediately, and returns the result in the same response.
    For generator handlers every yielded "output" value is collected into a
    list.

    Returns:
        A "COMPLETED" response carrying the output, or a "FAILED" response
        carrying the error when the handler result contains no "output".
    """
    assigned_job_id = f"test-{uuid.uuid4()}"
    job = TestJob(id=assigned_job_id, input=job_input.input)
    handler = self.config["handler"]

    if is_generator(handler):
        job_output = {"output": []}
        async for stream_output in run_job_generator(handler, job.__dict__):
            job_output["output"].append(stream_output["output"])
    else:
        job_output = await run_job(handler, job.__dict__)

    # A result without "output" (e.g. one that only carries "error") used to
    # raise KeyError here; report it as a failed job instead.
    if "output" not in job_output:
        return jsonable_encoder({
            "id": job.id,
            "status": "FAILED",
            "error": str(job_output.get("error", "Job returned no output"))
        })

    return jsonable_encoder({
        "id": job.id,
        "status": "COMPLETED",
        "output": job_output["output"]
    })
199
+
200
+ # ---------------------------------- stream ---------------------------------- #
201
async def _sim_stream(self, job_id: str) -> StreamOutput:
    """
    Development endpoint to simulate stream behavior.

    Looks up the stored input for `job_id`, runs the generator handler on it
    and returns every yielded output in a single response. Unknown job IDs
    and non-generator handlers produce a "FAILED" response; in the latter
    case the job is left in `job_list`.
    """
    stored_input = job_list.get_job_input(job_id)
    if stored_input is None:
        not_found = {"id": job_id, "status": "FAILED", "error": "Job ID not found"}
        return jsonable_encoder(not_found)

    job = TestJob(id=job_id, input=stored_input)
    handler = self.config["handler"]

    # Guard clause: streaming only makes sense for generator handlers.
    if not is_generator(handler):
        unsupported = {
            "id": job_id,
            "status": "FAILED",
            "error": "Stream not supported, handler must be a generator."
        }
        return jsonable_encoder(unsupported)

    # Drain the generator, wrapping each yielded output in its own dict.
    chunks = [
        {"output": part["output"]}
        async for part in run_job_generator(handler, job.__dict__)
    ]

    job_list.remove_job(job.id)

    finished = {"id": job_id, "status": "COMPLETED", "stream": chunks}
    return jsonable_encoder(finished)
232
+
233
+ # ---------------------------------- status ---------------------------------- #
234
async def _sim_status(self, job_id: str) -> JobOutput:
    """
    Development endpoint to simulate status behavior.

    Looks up the stored input for `job_id`, runs the configured handler on
    it, removes the job from `job_list` and returns the result. For generator
    handlers every yielded "output" value is collected into a list.

    Returns:
        A "COMPLETED" response carrying the output, or a "FAILED" response
        when the job ID is unknown or the handler result contains no
        "output".
    """
    job_input = job_list.get_job_input(job_id)
    if job_input is None:
        return jsonable_encoder({
            "id": job_id,
            "status": "FAILED",
            "error": "Job ID not found"
        })

    job = TestJob(id=job_id, input=job_input)
    handler = self.config["handler"]

    if is_generator(handler):
        job_output = {"output": []}
        async for stream_output in run_job_generator(handler, job.__dict__):
            job_output["output"].append(stream_output["output"])
    else:
        job_output = await run_job(handler, job.__dict__)

    job_list.remove_job(job.id)

    # A result without "output" (e.g. one that only carries "error") used to
    # raise KeyError here; report it as a failed job instead.
    if "output" not in job_output:
        return jsonable_encoder({
            "id": job_id,
            "status": "FAILED",
            "error": str(job_output.get("error", "Job returned no output"))
        })

    return jsonable_encoder({
        "id": job_id,
        "status": "COMPLETED",
        "output": job_output["output"]
    })
0 commit comments