Skip to content

Commit 102592a

Browse files
Merge pull request #81 from runpod/fix_job_gen
fix this
2 parents 4a6342d + 2329ae8 commit 102592a

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

runpod/serverless/modules/rp_job.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,12 @@ async def run_job_generator(
155155
'''
156156
try:
157157
job_output = handler(job)
158-
for output_partial in job_output:
159-
yield {"output": output_partial}
158+
if inspect.isasyncgenfunction(handler):
159+
async for output_partial in job_output:
160+
yield {"output": output_partial}
161+
else:
162+
for output_partial in job_output:
163+
yield {"output": output_partial}
160164
except Exception as err: # pylint: disable=broad-except
161165
log.error(f'Error while running job {job["id"]}: {err}')
162166
yield {"error": f"handler: {str(err)} \ntraceback: {traceback.format_exc()}"}

tests/test_serverless/test_modules/test_job.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,24 +214,46 @@ async def test_job_with_exception(self):
214214
class TestRunJobGenerator(IsolatedAsyncioTestCase):
215215
''' Tests the run_job_generator function '''
216216

217-
def handler_success(self, job): # pylint: disable=unused-argument
217+
def handler_gen_success(self, job): # pylint: disable=unused-argument
218218
'''
219-
Test handler that returns a generator
219+
Test handler that returns a generator.
220+
'''
221+
yield "partial_output_1"
222+
yield "partial_output_2"
223+
224+
async def handler_async_gen_success(self, job): # pylint: disable=unused-argument
225+
'''
226+
Test handler that returns an async generator.
220227
'''
221228
yield "partial_output_1"
222229
yield "partial_output_2"
223230

224231
def handler_fail(self, job):
225232
'''
226-
Test handler that raises an exception
233+
Test handler that raises an exception.
227234
'''
228235
raise Exception("Test Exception") # pylint: disable=broad-exception-raised
229236

230237
async def test_run_job_generator_success(self):
231238
'''
232239
Tests the run_job_generator function with a successful generator
233240
'''
234-
handler = self.handler_success
241+
handler = self.handler_gen_success
242+
job = {"id": "123"}
243+
244+
with patch("runpod.serverless.modules.rp_job.log", new_callable=Mock) as mock_log:
245+
result = [i async for i in rp_job.run_job_generator(handler, job)]
246+
247+
assert result == [{"output": "partial_output_1"}, {"output": "partial_output_2"}]
248+
assert mock_log.error.call_count == 0
249+
assert mock_log.info.call_count == 1
250+
mock_log.info.assert_called_with('123 | Finished ')
251+
252+
async def test_run_job_generator_success_async(self):
253+
'''
254+
Tests the run_job_generator function with a successful generator
255+
'''
256+
handler = self.handler_async_gen_success
235257
job = {"id": "123"}
236258

237259
with patch("runpod.serverless.modules.rp_job.log", new_callable=Mock) as mock_log:

0 commit comments

Comments
 (0)