Skip to content

Commit 5645bb1

Browse files
Merge pull request #259 from runpod/rust-core-integration
Rust core integration
2 parents 31f0fda + fcf5de6 commit 5645bb1

File tree

9 files changed

+264
-6
lines changed

9 files changed

+264
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ __pycache__/
1616

1717
# C extensions
1818
*.so
19+
!sls_core.so
1920

2021
# Distribution / packaging
2122
.Python

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
# Change Log
22

3+
## Release 1.5.0 (12/28/23)
4+
5+
### Added
6+
7+
- Optional serverless core implementation, use with environment variable `RUNPOD_USE_CORE=True` or `RUNPOD_CORE_PATH=/path/to/core.so`
8+
9+
### Changed
10+
11+
- Reduced the delay in *await asyncio.sleep* calls to 0 seconds to shorten job pickup time.
12+
13+
---
14+
315
## Release 1.4.2 (12/14/23)
416

517
### Fixed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,4 @@ dependencies = { file = ["requirements.txt"] }
7070

7171
# Used by pytest coverage
7272
[tool.coverage.run]
73-
omit = ["runpod/_version.py"]
73+
omit = ["runpod/_version.py", "runpod/serverless/core.py"]

runpod/serverless/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import argparse
1313
from typing import Dict, Any
1414

15+
from runpod.serverless import core
1516
from . import worker
1617
from .modules import rp_fastapi
1718
from .modules.rp_logger import RunPodLogger
@@ -125,7 +126,7 @@ def start(config: Dict[str, Any]):
125126
realtime_concurrency = _get_realtime_concurrency()
126127

127128
if config["rp_args"]["rp_serve_api"]:
128-
print("Starting API server.")
129+
log.info("Starting API server.")
129130
api_server = rp_fastapi.WorkerAPI(config)
130131

131132
api_server.start_uvicorn(
@@ -135,7 +136,7 @@ def start(config: Dict[str, Any]):
135136
)
136137

137138
elif realtime_port:
138-
print("Starting API server for realtime.")
139+
log.info("Starting API server for realtime.")
139140
api_server = rp_fastapi.WorkerAPI(config)
140141

141142
api_server.start_uvicorn(
@@ -144,5 +145,11 @@ def start(config: Dict[str, Any]):
144145
api_concurrency=realtime_concurrency
145146
)
146147

148+
# --------------------------------- SLS-Core --------------------------------- #
149+
elif os.environ.get("RUNPOD_USE_CORE", None) or os.environ.get("RUNPOD_CORE_PATH", None):
150+
log.info("Starting worker with SLS-Core.")
151+
core.main(config)
152+
153+
# --------------------------------- Standard --------------------------------- #
147154
else:
148155
worker.main(config)

runpod/serverless/core.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
""" Core functionality for the runpod serverless worker. """
2+
3+
import ctypes
4+
import inspect
5+
import json
6+
import os
7+
import pathlib
8+
import asyncio
9+
from ctypes import CDLL, byref, c_char_p, c_int
10+
from typing import Any, Callable, List, Dict, Optional
11+
12+
from runpod.serverless.modules.rp_logger import RunPodLogger
13+
14+
15+
log = RunPodLogger()
16+
17+
18+
class CGetJobResult(ctypes.Structure):  # pylint: disable=too-few-public-methods
    """
    Result record returned by ``_runpod_sls_get_jobs``.

    ## fields
    - ``status_code`` outcome of the call (see the Rust side for the code values).
    - ``res_len`` number of bytes written into the ``dst_buf`` that was passed
      to ``_runpod_sls_get_jobs``; only the first ``res_len`` bytes are valid.
    """

    _fields_ = [
        ("status_code", ctypes.c_int),
        ("res_len", ctypes.c_int),
    ]

    def __str__(self) -> str:
        return "CGetJobResult(res_len={}, status_code={})".format(
            self.res_len, self.status_code)
31+
32+
33+
class Hook: # pylint: disable=too-many-instance-attributes
    """ Singleton wrapper around the Rust serverless core (sls_core.so).

    Loads the shared library once per process, resolves and types the exported
    `_runpod_sls_*` C functions, and exposes them as Python methods. All
    instances share the same state (`__new__` returns a single cached object),
    so the `rust_so_path` argument only has an effect on the very first
    construction — later calls silently ignore it.
    """

    # Cached singleton instance; populated on first __new__.
    _instance = None

    # C function pointers (resolved from the shared library in __init__)
    _get_jobs: Callable = None
    _progress_update: Callable = None
    _stream_output: Callable = None
    _post_output: Callable = None
    _finish_stream: Callable = None

    def __new__(cls):
        # Classic singleton: create the instance once and mark it
        # uninitialized so __init__ knows whether to do the FFI setup.
        if Hook._instance is None:
            Hook._instance = object.__new__(cls)
            Hook._initialized = False
        return Hook._instance

    def __init__(self, rust_so_path: Optional[str] = None) -> None:
        # Re-entrant __init__ guard: the singleton is set up only once.
        if self._initialized:
            return

        if rust_so_path is None:
            # Default to sls_core.so sitting next to this module.
            default_path = os.path.join(
                pathlib.Path(__file__).parent.absolute(), "sls_core.so"
            )
            # NOTE(review): this reads RUNPOD_SLS_CORE_PATH, while the
            # changelog/start() logic mentions RUNPOD_CORE_PATH — confirm
            # which env var name is the intended public one.
            self.rust_so_path = os.environ.get("RUNPOD_SLS_CORE_PATH", str(default_path))
        else:
            self.rust_so_path = rust_so_path

        rust_library = CDLL(self.rust_so_path)
        buffer = ctypes.create_string_buffer(1024)  # 1 KiB scratch for the version string
        # _runpod_sls_crate_version writes the crate version into `buffer`
        # and returns the number of bytes written.
        num_bytes = rust_library._runpod_sls_crate_version(byref(buffer), c_int(len(buffer)))

        self.rust_crate_version = buffer.raw[:num_bytes].decode("utf-8")

        # Get Jobs: returns a CGetJobResult struct (status + bytes written).
        self._get_jobs = rust_library._runpod_sls_get_jobs
        self._get_jobs.restype = CGetJobResult

        # Progress Update
        self._progress_update = rust_library._runpod_sls_progress_update
        self._progress_update.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int   # json_ptr, json_len
        ]
        self._progress_update.restype = c_int  # 1 if success, 0 if failure

        # Stream Output
        self._stream_output = rust_library._runpod_sls_stream_output
        self._stream_output.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int,  # json_ptr, json_len
        ]
        self._stream_output.restype = c_int  # 1 if success, 0 if failure

        # Post Output
        self._post_output = rust_library._runpod_sls_post_output
        self._post_output.argtypes = [
            c_char_p, c_int,  # id_ptr, id_len
            c_char_p, c_int,  # json_ptr, json_len
        ]
        self._post_output.restype = c_int  # 1 if success, 0 if failure

        # Finish Stream
        self._finish_stream = rust_library._runpod_sls_finish_stream
        self._finish_stream.argtypes = [c_char_p, c_int]  # id_ptr, id_len
        self._finish_stream.restype = c_int  # 1 if success, 0 if failure

        rust_library._runpod_sls_crate_version.restype = c_int

        # Initialize the core; presumably sets up its internal runtime —
        # the return code is not checked here (TODO confirm it can fail).
        rust_library._runpod_sls_init.argtypes = []
        rust_library._runpod_sls_init.restype = c_int
        rust_library._runpod_sls_init()

        self._initialized = True

    def _json_serialize_job_data(self, job_data: Any) -> bytes:
        # UTF-8 JSON bytes, non-ASCII preserved (ensure_ascii=False).
        return json.dumps(job_data, ensure_ascii=False).encode("utf-8")

    def get_jobs(self, max_concurrency: int, max_jobs: int) -> List[Dict[str, Any]]:
        """Get a job or jobs from the queue. The jobs are returned as a list of Job objects.

        Returns an empty list when the core reports status 0 (no jobs yet);
        raises RuntimeError for any status other than 0 or 1.
        """
        buffer = ctypes.create_string_buffer(1024 * 1024 * 20)  # 20MB buffer to store jobs in
        destination_length = len(buffer.raw)
        result: CGetJobResult = self._get_jobs(
            c_int(max_concurrency), c_int(max_jobs),
            byref(buffer), c_int(destination_length)
        )
        if result.status_code == 1:  # success! the jobs were stored in bytes 0..res_len of buffer.raw
            return list(json.loads(buffer.raw[: result.res_len].decode("utf-8")))

        if result.status_code not in [0, 1]:
            raise RuntimeError(f"get_jobs failed with status code {result.status_code}")

        return []  # Status code 0, still waiting for jobs

    def progress_update(self, job_id: str, json_data: bytes) -> bool:
        """
        Send a progress update to AI-API.

        Unlike stream_output/post_output, `json_data` is passed through as-is
        and must already be JSON-encoded bytes. Returns True on success.
        """
        id_bytes = job_id.encode("utf-8")
        return bool(self._progress_update(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def stream_output(self, job_id: str, job_output: bytes) -> bool:
        """
        Send part of a streaming result to AI-API.

        `job_output` is JSON-serialized here before being handed to the core,
        so any JSON-serializable value is accepted (the `bytes` annotation is
        historical). Returns True on success.
        """
        json_data = self._json_serialize_job_data(job_output)
        id_bytes = job_id.encode("utf-8")
        return bool(self._stream_output(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def post_output(self, job_id: str, job_output: bytes) -> bool:
        """
        Send the final result of a job to AI-API.

        `job_output` is JSON-serialized here before being handed to the core.
        Returns True if the result was successfully stored, False otherwise.
        """
        json_data = self._json_serialize_job_data(job_output)
        id_bytes = job_id.encode("utf-8")
        return bool(self._post_output(
            c_char_p(id_bytes), c_int(len(id_bytes)),
            c_char_p(json_data), c_int(len(json_data))
        ))

    def finish_stream(self, job_id: str) -> bool:
        """
        Tell the SLS queue that the result stream of a streaming job is complete.
        """
        id_bytes = job_id.encode("utf-8")
        return bool(self._finish_stream(
            c_char_p(id_bytes), c_int(len(id_bytes))
        ))
170+
171+
172+
# -------------------------------- Process Job ------------------------------- #
173+
def _process_job(handler: Callable, job: Dict[str, Any]) -> None:
    """ Process a single job.

    Calls ``handler(job)`` and reports the result through the SLS core:
    generator handlers are streamed part-by-part then finalized, plain
    handlers post one final output.

    Args:
        handler: User-supplied handler (plain function or generator function).
        job: Job dict from the queue; must contain an ``id`` key.

    Raises:
        RuntimeError: If user code raises — either on the initial call or,
            for generator handlers, lazily while the generator is iterated.
    """
    hook = Hook()

    try:
        result = handler(job)
    except Exception as err:
        raise RuntimeError(
            f"run {job['id']}: user code raised an {type(err).__name__}") from err

    if inspect.isgeneratorfunction(handler):
        # A generator body runs lazily: user code executes during iteration,
        # not at call time, so errors raised there must be wrapped too (the
        # original try above never covered them).
        while True:
            try:
                part = next(result)
            except StopIteration:
                break
            except Exception as err:
                raise RuntimeError(
                    f"run {job['id']}: user code raised an {type(err).__name__}") from err
            hook.stream_output(job['id'], part)

        hook.finish_stream(job['id'])

    else:
        hook.post_output(job['id'], result)
191+
192+
193+
# -------------------------------- Run Worker -------------------------------- #
194+
async def run(config: Dict[str, Any]) -> None:
    """ Run the worker.

    Polls the SLS core for jobs forever and dispatches each job to the user
    handler on the default thread-pool executor, so a slow handler does not
    block the polling loop.

    Args:
        config: A dictionary containing the following keys:
            handler: A function that takes a job and returns a result.
            max_concurrency: (optional) max concurrent jobs, defaults to 4.
            max_jobs: (optional) max jobs fetched per poll, defaults to 4.
    """
    handler = config['handler']
    max_concurrency = config.get('max_concurrency', 4)
    max_jobs = config.get('max_jobs', 4)

    hook = Hook()
    loop = asyncio.get_running_loop()

    while True:
        jobs = hook.get_jobs(max_concurrency, max_jobs)

        if not jobs:
            # Yield control even when idle: the previous bare `continue`
            # skipped every `await`, starving all other tasks on the loop.
            await asyncio.sleep(0)
            continue

        for job in jobs:
            # _process_job is a plain (sync) function; asyncio.create_task()
            # requires a coroutine and raised TypeError when handed its return
            # value. Run it on the executor to get actual concurrency.
            loop.run_in_executor(None, _process_job, handler, job)
            await asyncio.sleep(0)

        await asyncio.sleep(0)
218+
219+
220+
def main(config: Dict[str, Any]) -> None:
    """Run the worker in a dedicated asyncio event loop.

    Blocks forever (run_forever); the loop is always closed on the way out.

    Args:
        config: Worker configuration passed straight through to `run`.
    """
    # Create the loop BEFORE entering the try block: if creation itself
    # failed, the original `finally` hit a NameError on the unbound
    # `work_loop`, masking the real error.
    work_loop = asyncio.new_event_loop()
    try:
        asyncio.ensure_future(run(config), loop=work_loop)
        work_loop.run_forever()
    finally:
        work_loop.close()

runpod/serverless/modules/rp_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ async def get_job(session: ClientSession, retry=True) -> Optional[Dict[str, Any]
106106
if retry is False:
107107
break
108108

109-
await asyncio.sleep(1)
109+
await asyncio.sleep(0)
110110
else:
111111
job_list.add_job(next_job["id"])
112112
log.debug("Request ID added.", next_job['id'])

runpod/serverless/modules/rp_scale.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ async def get_jobs(self, session):
8181
if job:
8282
yield job
8383

84-
await asyncio.sleep(1)
85-
84+
await asyncio.sleep(0)
8685

8786
log.debug(f"Concurrency set to: {self.current_concurrency}")

runpod/serverless/sls_core.so

9.71 MB
Binary file not shown.

tests/test_serverless/test_worker.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,15 @@ def mock_is_alive():
511511
# 5 calls with actual jobs
512512
assert mock_run_job.call_count == 5
513513
assert mock_send_result.call_count == 5
514+
515+
# Test with sls-core
516+
async def test_run_worker_with_sls_core(self):
517+
'''
518+
Test run_worker with sls-core.
519+
'''
520+
os.environ["RUNPOD_USE_CORE"] = "true"
521+
522+
with patch("runpod.serverless.core.main") as mock_main:
523+
runpod.serverless.start(self.config)
524+
525+
assert mock_main.called

0 commit comments

Comments
 (0)