Skip to content

Commit c22b523

Browse files
Merge pull request #255 from runpod/fix-pydantic
Fix pydantic
2 parents 3059ed0 + 2bb5cfa commit c22b523

File tree

3 files changed

+53
-12
lines changed

3 files changed

+53
-12
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Change Log
22

3+
## Release 1.4.2 (12/14/23)
4+
5+
### Fixed
6+
7+
- Added defaults for optional parameters in `rp_fastapi` to be compatible with pydantic.
8+
39
## Release 1.4.1 (12/13/23)
410

511
### Added

runpod/serverless/modules/rp_fastapi.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
import os
55
import uuid
6+
from dataclasses import dataclass
67
from typing import Union, Optional, Dict, Any
78

89
import uvicorn
910
from fastapi import FastAPI, APIRouter
1011
from fastapi.encoders import jsonable_encoder
1112
from fastapi.responses import RedirectResponse
12-
from pydantic import BaseModel
1313

1414
from .rp_handler import is_generator
1515
from .rp_job import run_job, run_job_generator
@@ -39,40 +39,45 @@
3939

4040

4141
# ------------------------------- Input Objects ------------------------------ #
42-
class Job(BaseModel):
42+
@dataclass
43+
class Job:
4344
''' Represents a job. '''
4445
id: str
4546
input: Union[dict, list, str, int, float, bool]
4647

4748

48-
class TestJob(BaseModel):
49+
@dataclass
50+
class TestJob:
4951
''' Represents a test job.
5052
input can be any type of data.
5153
'''
52-
id: Optional[str]
53-
input: Optional[Union[dict, list, str, int, float, bool]]
54+
id: Optional[str] = None
55+
input: Optional[Union[dict, list, str, int, float, bool]] = None
5456

5557

56-
class DefaultInput(BaseModel):
58+
@dataclass
59+
class DefaultInput:
5760
""" Represents a test input. """
5861
input: Dict[str, Any]
5962

6063

6164
# ------------------------------ Output Objects ------------------------------ #
62-
class JobOutput(BaseModel):
65+
@dataclass
66+
class JobOutput:
6367
''' Represents the output of a job. '''
6468
id: str
6569
status: str
66-
output: Optional[Union[dict, list, str, int, float, bool]]
67-
error: Optional[str]
70+
output: Optional[Union[dict, list, str, int, float, bool]] = None
71+
error: Optional[str] = None
6872

6973

70-
class StreamOutput(BaseModel):
74+
@dataclass
75+
class StreamOutput:
7176
""" Stream representation of a job. """
7277
id: str
7378
status: str = "IN_PROGRESS"
74-
stream: Optional[Union[dict, list, str, int, float, bool]]
75-
error: Optional[str]
79+
stream: Optional[Union[dict, list, str, int, float, bool]] = None
80+
error: Optional[str] = None
7681

7782

7883
# ---------------------------------------------------------------------------- #
@@ -191,6 +196,13 @@ async def _sim_runsync(self, job_input: DefaultInput) -> JobOutput:
191196
else:
192197
job_output = await run_job(self.config["handler"], job.__dict__)
193198

199+
if job_output.get('error', None):
200+
return jsonable_encoder({
201+
"id": job.id,
202+
"status": "FAILED",
203+
"error": job_output['error']
204+
})
205+
194206
return jsonable_encoder({
195207
"id": job.id,
196208
"status": "COMPLETED",
@@ -253,6 +265,13 @@ async def _sim_status(self, job_id: str) -> JobOutput:
253265

254266
job_list.remove_job(job.id)
255267

268+
if job_output.get('error', None):
269+
return jsonable_encoder({
270+
"id": job_id,
271+
"status": "FAILED",
272+
"error": job_output['error']
273+
})
274+
256275
return jsonable_encoder({
257276
"id": job_id,
258277
"status": "COMPLETED",

tests/test_serverless/test_modules/test_fastapi.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def setUp(self) -> None:
1919
self.handler = Mock()
2020
self.handler.return_value = {"result": "success"}
2121

22+
self.error_handler = Mock()
23+
self.error_handler.side_effect = Exception("test error")
24+
2225
def test_start_serverless_with_realtime(self):
2326
'''
2427
Tests the start_serverless() method with the realtime option.
@@ -139,6 +142,12 @@ def generator_handler(job):
139142
"output": [{"result": "success"}]
140143
}
141144

145+
# Test with error handler
146+
error_worker_api = rp_fastapi.WorkerAPI({"handler": self.error_handler})
147+
error_runsync_return = asyncio.run(
148+
error_worker_api._sim_runsync(default_input_object))
149+
assert "error" in error_runsync_return
150+
142151
loop.close()
143152

144153
@pytest.mark.asyncio
@@ -243,4 +252,11 @@ def generator_handler(job):
243252
"output": [{"result": "success"}]
244253
}
245254

255+
# Test with error handler
256+
error_worker_api = rp_fastapi.WorkerAPI({"handler": self.error_handler})
257+
asyncio.run(error_worker_api._sim_run(default_input_object))
258+
error_status_return = asyncio.run(
259+
error_worker_api._sim_status("test-123"))
260+
assert "error" in error_status_return
261+
246262
loop.close()

0 commit comments

Comments
 (0)